Spark: UDF Silently Skipped on Null Input

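Background

In a project I needed to derive a new column (new_country_id) from an existing column (country_id) of a DataFrame, so I used withColumn together with a udf. However, country_id contains null values, and for those rows the udf has no effect.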

Code

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.LongType


object Main {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("planner")
      .master("local[*]")
      .config("spark.driver.host", "127.0.0.1")
      .getOrCreate()
    // Intended to return 3 for every row, whatever the input value is.
    val myUDF = udf((countryID: Long) => {
      3L
    })
    import spark.implicits._
    val myDF = spark.sparkContext.parallelize(
      Seq(1L, 2L)
    ).toDF("id")
      .withColumn("country_id", when($"id" === 1 , lit(null).cast(LongType)).otherwise(lit(1)))
      .withColumn("new_country_id", myUDF($"country_id"))
    myDF.show(false)
  }
}

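The code above tries to make the udf return 3 for new_country_id in every row, but the result is not what we want: for id 1, country_id is null, and new_country_id comes back as null instead of 3.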

+---+----------+--------------+
|id |country_id|new_country_id|
+---+----------+--------------+
|1  |null      |null          |
|2  |1         |3             |
+---+----------+--------------+

Cause

  1. Catalyst checks whether each UDF input parameter type is a Scala primitive (and therefore cannot hold null): org.apache.spark.sql.catalyst.expressions.ScalaUDF#inputPrimitives
/**
   * The analyzer should be aware of Scala primitive types so as to make the
   * UDF return null if there is any null input value of these types. On the
   * other hand, Java UDFs can only have boxed types, thus this will return
   * Nil(has same effect with all false) and analyzer will skip null-handling
   * on them.
   */
  def inputPrimitives: Seq[Boolean] = {
    inputEncoders.map { encoderOpt =>
      // It's possible that some of the inputs don't have a specific encoder(e.g. `Any`)
      if (encoderOpt.isDefined) {
        val encoder = encoderOpt.get
        if (encoder.isSerializedAsStruct) {
          // struct type is not primitive
          false
        } else {
          // `nullable` is false iff the type is primitive
          !encoder.schema.head.nullable
        }
      } else {
        // Any type is not primitive
        false
      }
    }
  }
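  2. If a parameter type is primitive and the corresponding input expression is nullable, the analyzer wraps the UDF call in a null check: org.apache.spark.sql.catalyst.analysis.Analyzer.HandleNullInputsForUDF#apply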
/**
   * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
   * null check.  When user defines a UDF with primitive parameters, there is no way to tell if the
   * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
   * and we should return null if the input is null.
   */
  object HandleNullInputsForUDF extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
      case p if !p.resolved => p // Skip unresolved nodes.

      case p => p transformExpressionsUp {

        case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
            if udf.inputPrimitives.contains(true) =>
          // Otherwise, add special handling of null for fields that can't accept null.
          // The result of operations like this, when passed null, is generally to return null.
          assert(udf.inputPrimitives.length == inputs.length)

          val inputPrimitivesPair = udf.inputPrimitives.zip(inputs)
          val inputNullCheck = inputPrimitivesPair.collect {
            case (isPrimitive, input) if isPrimitive && input.nullable =>
              IsNull(input)
          }.reduceLeftOption[Expression](Or)

          if (inputNullCheck.isDefined) {
            // Once we add an `If` check above the udf, it is safe to mark those checked inputs
            // as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning
            // branch of `If` will be called if any of these checked inputs is null. Thus we can
            // prevent this rule from being applied repeatedly.
            val newInputs = inputPrimitivesPair.map {
              case (isPrimitive, input) =>
                if (isPrimitive && input.nullable) {
                  KnownNotNull(input)
                } else {
                  input
                }
            }
            val newUDF = udf.copy(children = newInputs)
            If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF)
          } else {
            udf
          }
      }
    }
  }

In other words, when the input value is null, the rewritten expression returns null directly and the udf body is never called.
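One way to see why the analyzer has to guard primitive parameters is to look at what a Scala primitive does with null. The following is a minimal sketch in plain Scala (no Spark involved, and not Spark's own code); the value names are made up for illustration.

```scala
// A Scala primitive Long has no null value: casting null yields the default 0L,
// so a function taking a Long could not tell "missing" apart from a real 0.
val unboxed: Long = null.asInstanceOf[Long]

// The boxed Java type can hold null, so the two cases stay distinguishable.
val boxed: java.lang.Long = null
val isMissing: Boolean = boxed == null
```

Because a primitive parameter would otherwise receive a meaningless default, returning null without calling the function is the conservative choice.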

Solutions

Based on the analysis above, there are two main fixes.

Solution 1: Declare the UDF parameter as the Java boxed type (java.lang.Long)

Change myUDF to:

val myUDF = udf((countryID: java.lang.Long) => {
    3L
})

Result:

+---+----------+--------------+
|id |country_id|new_country_id|
+---+----------+--------------+
|1  |null      |3             |
|2  |1         |3             |
+---+----------+--------------+

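Solution 2: Skip the custom udf and use the when function instead

This approach suits cases where the udf logic is fairly simple. For example, change the code that builds myDF to: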

val myDF = spark.sparkContext.parallelize(
      Seq(1L, 2L)
    ).toDF("id")
      .withColumn("country_id", when($"id" === 1 , lit(null).cast(LongType)).otherwise(lit(1)))
//      .withColumn("new_country_id", myUDF($"country_id"))
      .withColumn("new_country_id", when($"country_id".isNull, lit(3)).otherwise(3))

Result:

+---+----------+--------------+
|id |country_id|new_country_id|
+---+----------+--------------+
|1  |null      |3             |
|2  |1         |3             |
+---+----------+--------------+

Here the true and false branches of when return the same value only because this example is deliberately simple. In real code the two branches usually return different values.
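As a sketch of the more typical case where the branches differ, the snippet below maps a null country_id to -1 and adds 100 to any other value. The mapping rule and the names input, output, and result are assumptions made for illustration. It is written in script style and assumes Spark is on the classpath.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, when}
import org.apache.spark.sql.types.LongType

val spark = SparkSession.builder()
  .appName("when-sketch")
  .master("local[*]")
  .config("spark.driver.host", "127.0.0.1")
  .getOrCreate()
import spark.implicits._

// Same input as in the article: id 1 has a null country_id, id 2 has country_id 1.
val input = Seq(1L, 2L).toDF("id")
  .withColumn("country_id", when($"id" === 1, lit(null).cast(LongType)).otherwise(lit(1L)))

// Hypothetical rule: null -> -1, anything else -> country_id + 100.
val output = input.withColumn(
  "new_country_id",
  when($"country_id".isNull, lit(-1L)).otherwise($"country_id" + 100L)
)

// Collect (id, new_country_id) pairs in id order for inspection.
val result = output.orderBy("id").collect().map(r => (r.getLong(0), r.getLong(2))).toSeq
```

Since the null case is handled by the when expression itself, no udf is involved and the analyzer's null check for primitive UDF parameters does not apply.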
