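In a project I needed to derive a new column (new_country_id) on a DataFrame from an existing column (country_id), so I used withColumn + udf. However, country_id contains null values. For those rows the udf did not take effect.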
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.LongType

object Main {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("planner")
      .master("local[*]")
      .config("spark.driver.host", "127.0.0.1")
      .getOrCreate()

    // The parameter is a primitive Scala Long; the UDF should always return 3
    val myUDF = udf((countryID: Long) => {
      3L
    })

    import spark.implicits._
    val myDF = spark.sparkContext.parallelize(
      Seq(1L, 2L)
    ).toDF("id")
      // country_id is null for id = 1 and 1 for id = 2
      .withColumn("country_id", when($"id" === 1, lit(null).cast(LongType)).otherwise(lit(1)))
      .withColumn("new_country_id", myUDF($"country_id"))

    myDF.show(false)
  }
}
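The code above is meant to make the udf return 3 for new_country_id on every row, but the result is not what I expected. For id 1, country_id is null, and new_country_id is also null instead of 3.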
+---+----------+--------------+
|id |country_id|new_country_id|
+---+----------+--------------+
|1 |null |null |
|2 |1 |3 |
+---+----------+--------------+
The cause lies in how Spark treats UDF parameters of Scala primitive types. In the Spark source, `ScalaUDF.inputPrimitives` records which UDF inputs are primitive:
/**
 * The analyzer should be aware of Scala primitive types so as to make the
 * UDF return null if there is any null input value of these types. On the
 * other hand, Java UDFs can only have boxed types, thus this will return
 * Nil(has same effect with all false) and analyzer will skip null-handling
 * on them.
 */
def inputPrimitives: Seq[Boolean] = {
  inputEncoders.map { encoderOpt =>
    // It's possible that some of the inputs don't have a specific encoder(e.g. `Any`)
    if (encoderOpt.isDefined) {
      val encoder = encoderOpt.get
      if (encoder.isSerializedAsStruct) {
        // struct type is not primitive
        false
      } else {
        // `nullable` is false iff the type is primitive
        !encoder.schema.head.nullable
      }
    } else {
      // Any type is not primitive
      false
    }
  }
}
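Why do primitive parameters need special treatment at all? A primitive `Long` has no null value, so a null input cannot reach a `(countryID: Long) => ...` lambda intact. The sketch below is plain Scala, not Spark code. It uses the Scala runtime's `scala.runtime.BoxesRunTime.unboxToLong` to show that unboxing a null reference silently yields `0L`, which is indistinguishable from a real 0. This mirrors the source comment's remark that there is no way to tell whether a primitive parameter is null. It does not claim that Spark itself unboxes through this method. The object and method names are hypothetical.

```scala
import scala.runtime.BoxesRunTime

object PrimitiveNullDemo {
  // Unbox a possibly-null reference into a primitive Long the way the Scala
  // runtime does: a null reference silently becomes 0L.
  def toPrimitive(ref: AnyRef): Long = BoxesRunTime.unboxToLong(ref)

  // A boxed java.lang.Long, by contrast, can still carry null.
  def isMissing(ref: java.lang.Long): Boolean = ref == null

  def main(args: Array[String]): Unit = {
    println(toPrimitive(null))                       // 0
    println(toPrimitive(java.lang.Long.valueOf(0L))) // 0, same as the null case
    println(isMissing(null))                         // true
  }
}
```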
The analyzer rule `HandleNullInputsForUDF` then uses these flags to add a null check around the UDF call:
/**
 * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
 * null check. When user defines a UDF with primitive parameters, there is no way to tell if the
 * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
 * and we should return null if the input is null.
 */
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case p if !p.resolved => p // Skip unresolved nodes.

    case p => p transformExpressionsUp {

      case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
          if udf.inputPrimitives.contains(true) =>
        // Otherwise, add special handling of null for fields that can't accept null.
        // The result of operations like this, when passed null, is generally to return null.
        assert(udf.inputPrimitives.length == inputs.length)

        val inputPrimitivesPair = udf.inputPrimitives.zip(inputs)
        val inputNullCheck = inputPrimitivesPair.collect {
          case (isPrimitive, input) if isPrimitive && input.nullable =>
            IsNull(input)
        }.reduceLeftOption[Expression](Or)

        if (inputNullCheck.isDefined) {
          // Once we add an `If` check above the udf, it is safe to mark those checked inputs
          // as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning
          // branch of `If` will be called if any of these checked inputs is null. Thus we can
          // prevent this rule from being applied repeatedly.
          val newInputs = inputPrimitivesPair.map {
            case (isPrimitive, input) =>
              if (isPrimitive && input.nullable) {
                KnownNotNull(input)
              } else {
                input
              }
          }
          val newUDF = udf.copy(children = newInputs)
          If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF)
        } else {
          udf
        }
    }
  }
}
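In other words, when an input declared with a primitive type (here `countryID: Long`) is null, the analyzer wraps the UDF call in an `If` that returns null directly. The UDF body is never executed, which is why new_country_id is null for id 1.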
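The effect of the rule can be modeled in a few lines of plain Scala. This is a hypothetical sketch, not Spark's API. `guard` plays the role of the `If(IsNull(...), null, udf)` wrapper, and the flags stand in for `inputPrimitives`. Spark decides statically from `input.nullable` whether to add the check, while this model simply tests the runtime value.

```scala
object NullGuardSketch {
  // Hypothetical plain-Scala model of HandleNullInputsForUDF (not Spark API).
  // isPrimitive(i) tells whether the i-th UDF parameter has a Scala primitive type.
  def guard(isPrimitive: Seq[Boolean])(body: Seq[Any] => Any): Seq[Any] => Any =
    (inputs: Seq[Any]) => {
      require(isPrimitive.length == inputs.length, "one flag per input")
      // Counterpart of OR-ing IsNull(input) over the primitive inputs
      val primitiveInputIsNull = isPrimitive.zip(inputs).exists {
        case (primitive, value) => primitive && value == null
      }
      // Counterpart of If(check, Literal(null), udf): the body is not called
      if (primitiveInputIsNull) null else body(inputs)
    }

  def main(args: Array[String]): Unit = {
    val alwaysThree: Seq[Any] => Any = _ => 3L       // the body of myUDF
    val primitiveUdf = guard(Seq(true))(alwaysThree) // like (countryID: Long) => ...
    val boxedUdf = guard(Seq(false))(alwaysThree)    // like (countryID: java.lang.Long) => ...
    println(primitiveUdf(Seq(null))) // null
    println(primitiveUdf(Seq(1L)))   // 3
    println(boxedUdf(Seq(null)))     // 3
  }
}
```

The last two calls reproduce the two rows of the original output and preview the boxed-type fix.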
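Based on this analysis, there are two main solutions.
Solution 1: declare the parameter with a boxed type (`java.lang.Long`) so it is not treated as primitive. Change myUDF to: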
val myUDF = udf((countryID: java.lang.Long) => {
  3L
})
Result:
+---+----------+--------------+
|id |country_id|new_country_id|
+---+----------+--------------+
|1 |null |3 |
|2 |1 |3 |
+---+----------+--------------+
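A boxed parameter also lets the UDF body decide what a null input should map to, instead of always getting null back. The sketch below is a plain-Scala function that could be passed to Spark's `udf(...)`. The mapping (null to 3, otherwise the id times 10) is a hypothetical example, and the object and value names are illustrative.

```scala
object BoxedUdfBody {
  // Because the parameter is a boxed java.lang.Long, the analyzer does not add
  // the null short-circuit, so this body itself sees null and chooses a result.
  // The null -> 3, otherwise id * 10 mapping is hypothetical.
  val newCountryId: java.lang.Long => Long = countryID =>
    if (countryID == null) 3L
    else countryID.longValue * 10L

  def main(args: Array[String]): Unit = {
    println(newCountryId(null))                       // 3
    println(newCountryId(java.lang.Long.valueOf(1L))) // 10
  }
}
```

In the Spark job this would be wrapped as `val myUDF = udf(BoxedUdfBody.newCountryId)` and applied with `myUDF($"country_id")` as before.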
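Solution 2: if the custom UDF logic is simple, drop the UDF and handle null explicitly with built-in functions such as `when`. For example, change the code that builds myDF to: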
val myDF = spark.sparkContext.parallelize(
  Seq(1L, 2L)
).toDF("id")
  .withColumn("country_id", when($"id" === 1, lit(null).cast(LongType)).otherwise(lit(1)))
  // .withColumn("new_country_id", myUDF($"country_id"))
  .withColumn("new_country_id", when($"country_id".isNull, lit(3)).otherwise(3))
Result:
+---+----------+--------------+
|id |country_id|new_country_id|
+---+----------+--------------+
|1 |null |3 |
|2 |1 |3 |
+---+----------+--------------+
Here the true and false branches of `when` return the same value only because this example is deliberately simple. In real code the two branches usually return different results.
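In a real job the branches might look like `when($"country_id".isNull, lit(0L)).otherwise($"country_id" * 10)`. The 0 default and the times-10 mapping are hypothetical. The plain-Scala sketch below models, row by row, what such an expression evaluates to, with `Option[Long]` standing in for a nullable column value. It is a reasoning aid, not Spark code.

```scala
object WhenOtherwiseModel {
  // Per-row model of the hypothetical Spark expression
  //   when($"country_id".isNull, lit(0L)).otherwise($"country_id" * 10)
  // None stands for a null country_id.
  def newCountryId(countryID: Option[Long]): Long = countryID match {
    case None     => 0L       // true branch: country_id is null
    case Some(id) => id * 10L // false branch: country_id has a value
  }

  def main(args: Array[String]): Unit = {
    // The same two rows as myDF: id 1 has a null country_id, id 2 has country_id = 1
    val rows: Seq[(Long, Option[Long])] = Seq(1L -> None, 2L -> Some(1L))
    rows.foreach { case (id, countryID) =>
      println(s"id=$id new_country_id=${newCountryId(countryID)}")
    }
  }
}
```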