本文基于spark 3.2.0
由于codegen涉及到的知识点比较多,我们先来说清楚code""""""
,我们暂且叫做code代码块
要想搞清楚spark的code代码块,就得现搞清楚scala 字符串插值。
scala 字符串插值是2.10.0版本引用进来的新语法规则,可以直接允许使用者将变量引用直接插入到字符串中,如下:
val name = 'LI'
println(s"My name is $name")
输出:
My name is LI
这种资料很多,大家自行查阅资料理解。
因为这块代码比较复杂,直接拿出例子来运行:
直接找到spark CastSuite.scala 第215行如下:
test("cast string to boolean II") {
checkEvaluation(cast("abc", BooleanType), null)
之后在javaCode.scala 输出对应的想要debug的值,如下:
*/
def code(args: Any*): Block = {
sc.checkLengths(args)
if (sc.parts.length == 0) {
EmptyBlock
} else {
args.foreach {
case _: ExprValue | _: Inline | _: Block =>
case _: Boolean | _: Byte | _: Int | _: Long | _: Float | _: Double | _: String =>
case other => throw QueryExecutionErrors.cannotInterpolateClassIntoCodeBlockError(other)
}
val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
// scalasytle:off
println(s"code: $codeParts")
println(s"blockInputs: $blockInputs")
// scalasytle:on
CodeBlock(codeParts, blockInputs)
}
}
这样,运行后我们会发现,如下结果:
code: ArrayBuffer(
if (org.apache.spark.sql.catalyst.util.StringUtils.isTrueString(, )) {
, = true;
} else if (org.apache.spark.sql.catalyst.util.StringUtils.isFalseString(, )) {
, = false;
} else {
isNull_0 = true;
}
)
blockInputs: ArrayBuffer(((UTF8String) references[0] /* literal */), value_0, ((UTF8String) references[0] /* literal */), value_0)
result: if (org.apache.spark.sql.catalyst.util.StringUtils.isTrueString(((UTF8String) references[0] /* literal */))) {
value_0 = true;
} else if (org.apache.spark.sql.catalyst.util.StringUtils.isFalseString(((UTF8String) references[0] /* literal */))) {
value_0 = false;
} else {
isNull_0 = true;
}
...
而这段代码刚好和Cast.scala中的 castToBooleanCode方法是一一对应的的:
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
(c, evPrim, evNull) =>
val castFailureCode = if (ansiEnabled) {
s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c);"
} else {
s"$evNull = true;"
}
val result = code"""
if ($stringUtils.isTrueString($c)) {
$evPrim = true;
} else if ($stringUtils.isFalseString($c)) {
$evPrim = false;
} else {
$castFailureCode
}
"""
// scalastyle:off
println(s"result: $result")
// scalastyle:on
result
也就是说spark自定义的ExprValue类型的值被替换了(其实是Inline/Block/ExprValue这三种类型的值都会被替换,只不过这里没有体现),如下:
x | x |
---|---|
evPrim | 被替换成了((UTF8String) references[0] /* literal */) |
c | 被替换成了value_0 |
而输出的result结果就是拼接完后的完整字符串。
我们这里是为了debug,才会把结果和对应的片段打印出来,
而在spark真正处理的时候,返回的是ExprCode类型的值,在真正需要代码生成的时候,才会调用的toString的方法生成对应的字符串
但是我们在Cast.scala的方法中我们看到的doGenCode是先调用child.genCode的方法的:
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
ev.copy(code = eval.code +
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
}
那子节点的ExprCode怎么和父节点的ExprCode连接起来的呢?
其实这个和写代码的思路是一样的,每个子节点返回的ExprCode类型的值,都会对应为该方法体的的实现代码,返回值(包括了类型),spark额外增加了一个是否为null,如下:
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
其中code是对应的方法体的实现代码,
isNull 是对应的是否为null,
value 代表的返回值
至于为什么会额外增加一个是否为null,还是和写代码的逻辑是一样的,因为只有不为空的情况下,代码才会正常的往下运行:
protected[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
val javaType = JavaCode.javaType(resultType)
code"""
boolean $resultIsNull = $inputIsNull;
$javaType $result = ${CodeGenerator.defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
}
"""
}
这里的!$inputIsNull判断,只有不为空了才进行下一步的转换操作,要不然会抛出异常。
这样把子节点的结果作为父节点的入参传入给对应的方法,这样生成的代码完全符合编码的逻辑,这样这部分也就说完了,当然这部分也是代码生成的重中之重,理解了这部分,代码生成这块就差不多了,其他的就是各个部分的实现,用心去看即可。