Cast 强制类型转换(隐式类型转换)发生在 Logical Plan 转换为 Analyzed Logical Plan 的阶段:
先根据表达式重写的 inputTypes 方法(override def inputTypes)得到期望的输入类型,
待子节点全部解析完成(childrenResolved)后,再把各子表达式的实际类型与 inputTypes 逐一校验。
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
/**
* Casts types according to the expected input types for [[Expression]]s.
*/
object ImplicitTypeCasts extends TypeCoercionRule {
...
/**
 * Tries to make expression `e` conform to `expectedType` through an implicit cast.
 *
 * The target data type is looked up with the (DataType, AbstractDataType) overload of
 * `implicitCast`. When the lookup yields the type `e` already has, `e` is returned untouched;
 * when it yields a different type, `e` is wrapped in a [[Cast]] to that type; when it yields
 * nothing, the result is None.
 */
def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
  val resolvedType: Option[DataType] = implicitCast(e.dataType, expectedType)
  resolvedType.map { targetType =>
    if (targetType != e.dataType) Cast(e, targetType) else e
  }
}
/**
 * Works out which [[DataType]] a value of type `inType` should be implicitly cast to so that
 * it satisfies `expectedType`.
 *
 * Returns Some(inType) when the input type is already accepted, Some(target) when one of the
 * rules below applies, and None when no rule applies. The rules are tried top to bottom and
 * the first match wins, so their order is significant.
 */
private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = {
  // The match yields null for "no implicit cast available" so the branches do not need
  // Some(...) wrappers; the value is converted to an Option on the last line.
  @Nullable val castTarget: DataType = (inType, expectedType) match {
    // The expected type already accepts the input type: nothing to cast.
    case _ if expectedType.acceptsType(inType) => inType

    // NullType input (usually a null literal): use the expected type's default concrete type.
    case (NullType, wanted) => wanted.defaultConcreteType

    // String input where any numeric type is accepted: follow the Hive convention and use the
    // default concrete numeric type (double).
    case (StringType, NumericType) => NumericType.defaultConcreteType

    // Numeric input that was not accepted above.
    // A decimal is expected: derive the decimal type from the input numeric type.
    case (numeric: NumericType, DecimalType) => DecimalType.forType(numeric)
    // Any other numeric pair casts to the expected type, e.g. long -> int, int -> long.
    case (_: NumericType, wanted: NumericType) => wanted

    // Date <-> timestamp.
    case (DateType, TimestampType) => TimestampType
    case (TimestampType, DateType) => DateType

    // String input cast to decimal, numeric, date, timestamp or binary.
    case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT
    case (StringType, wanted: NumericType) => wanted
    case (StringType, DateType) => DateType
    case (StringType, TimestampType) => TimestampType
    case (StringType, BinaryType) => BinaryType

    // Any non-string atomic type can be cast to string.
    case (atomic: AtomicType, StringType) if atomic != StringType => StringType

    // None of the collection's member types accepted the input directly: take the first
    // member type that the input can be implicitly cast to, if any.
    case (_, TypeCollection(candidates)) =>
      candidates.flatMap(candidate => implicitCast(inType, candidate)).headOption.orNull

    // Array -> array. Whether the element nullability can be cast is decided as follows:
    //   1. a nullable target always allows the cast;
    //   2. a nullable source with a non-nullable target never allows it;
    //   3. with both non-nullable, the cast is allowed only when
    //      Cast.forceNullable(fromType, toType) is false.
    case (ArrayType(fromType, _), ArrayType(toType: DataType, true)) =>
      implicitCast(fromType, toType).map(elementType => ArrayType(elementType, true)).orNull

    case (ArrayType(_, true), ArrayType(_: DataType, false)) => null

    case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
        if !Cast.forceNullable(fromType, toType) =>
      implicitCast(fromType, toType).map(elementType => ArrayType(elementType, false)).orNull

    // Map -> map, with the same nullability semantics as the array rules. The guard rejects
    // a key cast that forceNullable reports as forcing nullability, as well as a nullability
    // change that Cast.resolvableNullability does not allow.
    case (MapType(fromKeyType, fromValueType, fromNullable),
        MapType(toKeyType, toValueType, toNullable))
        if !Cast.forceNullable(fromKeyType, toKeyType) &&
          Cast.resolvableNullability(fromNullable, toNullable) =>
      if (!toNullable && Cast.forceNullable(fromValueType, toValueType)) {
        // The value cast would force nullability but the target flag is false.
        null
      } else {
        // Both the key type and the value type must be implicitly castable.
        val castMapType = for {
          keyType <- implicitCast(fromKeyType, toKeyType)
          valueType <- implicitCast(fromValueType, toValueType)
        } yield MapType(keyType, valueType, toNullable)
        castMapType.orNull
      }

    // No rule applies.
    case _ => null
  }
  Option(castTarget)
}
...
}
// Validates this expression's children against its declared inputTypes by delegating to the
// shared ExpectsInputTypes.checkInputDataTypes helper.
override def checkInputDataTypes(): TypeCheckResult =
  ExpectsInputTypes.checkInputDataTypes(children, inputTypes)
object ExpectsInputTypes {

  /**
   * Checks each input expression against the expected type at the same position.
   *
   * Inputs and expected types are paired with `zip`, so only as many pairs as the shorter of
   * the two sequences are checked. Every input whose data type is not accepted by its
   * expected type contributes one message; the messages are joined with a single space.
   *
   * @param inputs the argument expressions to validate
   * @param inputTypes the expected type for each argument, in the same order
   * @return TypeCheckSuccess when no pair mismatches, otherwise a TypeCheckFailure carrying
   *         the joined messages
   */
  def checkInputDataTypes(
      inputs: Seq[Expression],
      inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
    val errors = for {
      ((arg, wanted), position) <- inputs.zip(inputTypes).zipWithIndex
      if !wanted.acceptsType(arg.dataType)
    } yield {
      // Argument positions are reported 1-based.
      s"argument ${position + 1} requires ${wanted.simpleString} type, " +
        s"however, '${arg.sql}' is of ${arg.dataType.catalogString} type."
    }

    if (errors.nonEmpty) {
      TypeCheckResult.TypeCheckFailure(errors.mkString(" "))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }
}
下面以 ShiftLeft 为例:
/**
 * Bitwise left shift expression.
 *
 * The first input must be an integer or a long and the second an integer (see `inputTypes`);
 * the result has the same data type as the first input.
 *
 * @param left the base number to shift.
 * @param right number of bits to left shift.
 */
@ExpressionDescription(
  usage = "_FUNC_(base, expr) - Bitwise left shift.",
  examples = """
Examples:
> SELECT _FUNC_(2, 1);
4
""",
  since = "1.5.0")
case class ShiftLeft(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  // Base: integer or long; shift amount: integer.
  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(IntegerType, LongType), IntegerType)

  // The shifted value keeps the type of the base operand.
  override def dataType: DataType = left.dataType

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    // Evaluated only inside a matching branch, exactly where the cast is needed.
    def bits: Int = input2.asInstanceOf[jl.Integer]
    input1 match {
      case i: jl.Integer => i << bits
      case l: jl.Long => l << bits
    }
  }

  // Hands defineCodeGen a function that formats the two operand strings as `a << b`.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    defineCodeGen(ctx, ev, (base, shift) => s"$base << $shift")
  }
}
zip 会把两个集合中的元素按顺序一一配对:
scala> val numbers = Seq(0, 1, 2, 3, 4)
numbers: Seq[Int] = List(0, 1, 2, 3, 4)
scala> val series = Seq(10, 11, 12, 13, 14)
series: Seq[Int] = List(10, 11, 12, 13, 14)
scala> numbers zip series
res0: Seq[(Int, Int)] = List((0,10), (1,11), (2,12), (3,13), (4,14))
scala> numbers.zip(series)
res1: Seq[(Int, Int)] = List((0,10), (1,11), (2,12), (3,13), (4,14))
如果其中一个集合更长,多出的元素会被丢弃,结果的长度等于较短集合的长度。
例如:
scala> val series = Seq(10, 11, 12, 13, 14, 15)
series: Seq[Int] = List(10, 11, 12, 13, 14, 15)
scala> numbers zip series
res2: Seq[(Int, Int)] = List((0,10), (1,11), (2,12), (3,13), (4,14))