Java UDF with PySpark
0. Problem
When using Spark it’s always advised to stick to the built-in Spark functions. However, sometimes we need more advanced functions or external libraries.
As an example we might want to use the H3 library to work with geospatial data.
The library is implemented in multiple languages, including Python and Java.
In this post we will explore different options for using this library with Spark.
1. Creating a Python UDF
The first option is to create a Python UDF using the h3 Python library. Before creating the UDF we can try the library with:
import h3
h3.geo_to_h3(0, 0, 8)
> '88754e6499fffff'
To create the UDF we only need to use the udf decorator and specify the output type. It’s also a good idea to handle cases where latitude or longitude might be null/None.
import pyspark.sql.functions as F
import pyspark.sql.types as T

@F.udf(T.StringType())
def geo_to_h3_py(latitude, longitude, resolution):
    if latitude is None or longitude is None:
        return None
    return h3.geo_to_h3(latitude, longitude, resolution)
Once it’s defined it can be called with:
(
    sdf
    .withColumn("h8", geo_to_h3_py("latitude", "longitude", F.lit(8)))
    .show()
)
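The snippet above assumes a Spark DataFrame sdf with latitude and longitude columns. As a minimal sketch (the session creation and the sample values are illustrative), such a DataFrame could be built with:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data with a null coordinate to exercise the None handling
sdf = spark.createDataFrame(
    [(0.0, 0.0), (41.39, 2.17), (None, 2.17)],
    ["latitude", "longitude"],
)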
2. Use a Java function with Spark
2.1. Use H3 directly
The first thing needed to work with the Java H3 library is to download the jar. It can be downloaded from Maven/h3. Once you have the jar file you need to add it to the $SPARK_HOME/jars folder so that Spark can pick it up without extra configuration.
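If you prefer not to copy files into $SPARK_HOME/jars, an alternative is to point Spark at the jar when building the session. This is only a sketch; the path and version are illustrative and should match the file you downloaded:

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Path to the H3 jar you downloaded (illustrative name/version)
    .config("spark.jars", "jars/h3-3.7.2.jar")
    .getOrCreate()
)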
Then you can directly access the Java code, create the h3 instance and call the geoToH3Address function:
h3 = spark.sparkContext._jvm.com.uber.h3core.H3Core.newInstance()
h3.geoToH3Address(0.0, 0.0, 8)
> '88754e6499fffff'
However, it is not possible to create a UDF using this Java instance. The problem is that it’s a Java object living in the JVM, so it can’t be pickled.
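To illustrate the problem, here is a sketch of what fails (the exact error message depends on the Spark version, but it boils down to the py4j object not being serializable):

# This looks reasonable but fails with a pickling error,
# because the closure captures the py4j h3 Java object
geo_to_h3_java = F.udf(
    lambda lat, lng: h3.geoToH3Address(lat, lng, 8), T.StringType()
)
sdf.withColumn("h8", geo_to_h3_java("latitude", "longitude")).show()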
2.2. Create a callable Java function
Let’s create a simple Java static function that we can use without needing to create an instance. We can create the file SimpleH3.java inside the path /com/villoro/ (so that it’s better organized). Remember to define the package so that it matches the path.
2.2.1. Create the Java file
/com/villoro/SimpleH3.java
package com.villoro;

import java.io.IOException;

import com.uber.h3core.H3Core;

public class SimpleH3 {

    private static H3Core h3;

    public static String toH3Address(Double latitude, Double longitude, int resolution) {
        // Lazy instantiation: create the H3 instance only on the first call
        if (h3 == null) {
            try {
                h3 = H3Core.newInstance();
            } catch (IOException e) {
                return null;
            }
        }

        // Check that coordinates are not null
        if (latitude == null || longitude == null) {
            return null;
        } else {
            return h3.geoToH3Address(latitude, longitude, resolution);
        }
    }
}
The only difference from the Python implementation is that we are doing a lazy instantiation. That means that the first time the function is run it will create the h3 instance, but the next times it will reuse the same instance. Lazy instantiation makes the code faster since the instance is only created once.
2.2.2. Compile the code
Once we have our Java code we need to compile it. Since we are importing things from com.uber.h3core.H3Core we need that jar. You can copy the jar you downloaded and save it as jars/h3.jar. Then you can compile the code with:
javac -cp "jars/h3.jar" com\villoro\*.java
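Note that the command above uses Windows path separators; on Linux/macOS the equivalent would be:

javac -cp "jars/h3.jar" com/villoro/*.java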
2.2.3. Create the jar
To create the jar we simply specify that we want SimpleH3.jar to contain all the .class files from com\villoro\:
jar cvf SimpleH3.jar com\villoro\*.class
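Again, with forward slashes on Linux/macOS:

jar cvf SimpleH3.jar com/villoro/*.class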
Once you have created the jar, remember to copy it to $SPARK_HOME/jars.
2.2.4. Call the static Java function with Spark
We can do something similar to call the new static function:
# Using our static function
spark.sparkContext._jvm.com.villoro.SimpleH3.toH3Address(0.0, 0.0, 8)
However, this still can’t be pickled since it’s a Java object, so we can’t create a UDF from Python.
3. Create a Java UDF
The solution will be to create the UDF with Java. The first thing you will need is to get the spark-sql jar from $SPARK_HOME/jars and copy it to the jars/ folder as spark-sql.jar. Without this jar you won’t be able to compile the UDF with Java.
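For example, on Linux/macOS (the exact file name depends on your Spark and Scala versions, so adjust it accordingly):

cp $SPARK_HOME/jars/spark-sql_*.jar jars/spark-sql.jar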
3.1. Create the Java function
Then to create the UDF you need to implement one of the UDFX interfaces. Since it will have three inputs we will use UDF3. We will need to declare all the input types (Double, Double and Integer) as well as the output (String). So it will implement UDF3<Double, Double, Integer, String>. The rest of the code is really similar to the previous function.
The rest of the code it’s really similar to the previous function.
/com/villoro/toH3AddressUDF.java
package com.villoro;

import java.io.IOException;

import com.uber.h3core.H3Core;
import org.apache.spark.sql.api.java.UDF3;

public class toH3AddressUDF implements UDF3<Double, Double, Integer, String> {

    private H3Core h3;

    @Override
    public String call(Double latitude, Double longitude, Integer resolution) throws Exception {
        // Lazy instantiation: create the H3 instance only on the first call
        if (h3 == null) {
            try {
                h3 = H3Core.newInstance();
            } catch (IOException e) {
                return null;
            }
        }

        // Check that coordinates are not null
        if (latitude == null || longitude == null) {
            return null;
        } else {
            return h3.geoToH3Address(latitude, longitude, resolution);
        }
    }
}
3.2. Compile the code
The command is really similar to the one we used before:
javac -cp "jars/h3.jar;jars/spark-sql.jar" com\villoro\*.java
Make sure the names of the jars in the classpath match the files you saved. Also note that ; is the classpath separator on Windows; on Linux/macOS use : instead.
3.3. Create the jar
We can use exactly the same command as before:
jar cvf SimpleH3.jar com\villoro\*.class
Once you have created the jar, remember to overwrite the old one in $SPARK_HOME/jars.
3.4. Call the Java UDF with Spark
Before calling the function we need to register it with:
spark.udf.registerJavaFunction("geo_to_h3", "com.villoro.toH3AddressUDF", T.StringType())
Then we can call it using F.expr:
(
    sdf
    .withColumn("h8", F.expr("geo_to_h3(latitude, longitude, 8)"))
    .show()
)
And we can also use it with SQL:
spark.sql("""
    SELECT
        latitude, longitude, geo_to_h3(latitude, longitude, 8)
    FROM my_table
""").show()
4. Performance results
In order to test it I created 6 datasets with different sizes.
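As a rough sketch of how such a benchmark could be run (the data generation and timing below are illustrative, not the exact setup used for these numbers; the noop write format requires Spark 3+ and forces the h8 column to actually be computed):

import time

for num_rows in [10**3, 10**4, 10**5, 10**6, 10**7, 10**8]:
    data = spark.range(num_rows).selectExpr(
        "rand() * 180 - 90 AS latitude",
        "rand() * 360 - 180 AS longitude",
    )

    for name, column in [
        ("python", geo_to_h3_py("latitude", "longitude", F.lit(8))),
        ("java", F.expr("geo_to_h3(latitude, longitude, 8)")),
    ]:
        start = time.monotonic()
        data.withColumn("h8", column).write.format("noop").mode("overwrite").save()
        print(f"{num_rows} rows, {name}: {time.monotonic() - start:.4f} s")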
The total execution times for the python udf and java udf are:
num_rows | python [s] | java [s] |
---|---|---|
10^3 | 2.3486 | 0.4926 |
10^4 | 3.1402 | 0.5794 |
10^5 | 3.8764 | 0.7958 |
10^6 | 5.5529 | 1.9590 |
10^7 | 24.1512 | 11.5077 |
10^8 | 195.0326 | 88.8759 |
As we expected the java_udf is faster than the python_udf. If we take a closer look at the values we can see that creating the python_udf has a fixed cost. That is why the results are so bad for the small datasets. Once the dataset is bigger that cost is no longer visible, but there is still the cost of moving the data from python to java. That is why the python_udf is around 2 times slower.