Implementing Custom Functions in Spark SQL
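Some readers may ask: Spark SQL already provides extension interfaces for writing UDFs and UDAFs, so why develop custom functions at all?
Although Spark SQL provides UDFs and UDAFs, they cannot do everything a built-in function can, such as keyword-style (semantic) arguments or a variable number of arguments.
Take a function like `substr`. A UDF cannot implement it, because a UDF accepts neither the `'Spark SQL' FROM 5` syntax nor a call in which the last of the two trailing arguments is optional: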
```
> SELECT substr('Spark SQL', 5);
k SQL
> SELECT substr('Spark SQL', -3);
SQL
> SELECT substr('Spark SQL', 5, 1);
k
> SELECT substr('Spark SQL' FROM 5);
k SQL
> SELECT substr('Spark SQL' FROM -3);
SQL
> SELECT substr('Spark SQL' FROM 5 FOR 1);
k
```
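To make the two call shapes concrete, here is a plain-Scala sketch of the positional rules shown above: the start is 1-based, a negative start counts from the end, and the length is optional. It is a simplified model for illustration, not Spark's implementation, and the name `SubstrModel` is hypothetical. The `build` method mimics how a function builder receiving a `Seq` of arguments can dispatch on their count, which is exactly what a fixed-arity UDF cannot do.

```scala
object SubstrModel {
  // Simplified model of substr(str, pos[, len]):
  // pos > 0 is 1-based, pos < 0 counts from the end, pos == 0 behaves like 1.
  def substr(str: String, pos: Int, len: Option[Int] = None): String = {
    val n = str.length
    val start = if (pos > 0) pos - 1 else if (pos < 0) n + pos else 0
    // Without a length, take everything up to the end of the string.
    val end = len.fold(n)(l => math.min(n.toLong, start.toLong + l).toInt)
    val from = math.max(start, 0)
    if (from >= end) "" else str.substring(from, end)
  }

  // Arity dispatch in the spirit of a Seq[Expression] => Expression builder,
  // here over plain values: two arguments omit the length, three supply it.
  def build(args: Seq[Any]): String = args match {
    case Seq(s: String, p: Int)         => substr(s, p)
    case Seq(s: String, p: Int, l: Int) => substr(s, p, Some(l))
    case _ =>
      throw new IllegalArgumentException(s"substr expects 2 or 3 arguments, got ${args.size}")
  }
}
```

For example, `SubstrModel.build(Seq("Spark SQL", 5, 1))` goes through the three-argument branch and yields `"k"`.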
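Spark provides the `SparkSessionExtensions` class. Through it you can inject custom extensions into many parts of Spark SQL, for example `injectOptimizerRule`, `injectParser`, and `injectFunction`, which is the hook for registering functions.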
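Why did I need this?
Because some Spark SQL functions did not fully do what I wanted.
For example, the built-in `to_timestamp` returns null for some arguments when the data does not match the specified pattern. In strict mode, i.e. with ANSI mode enabled, it throws an error instead.
The result I expected was `2020-08-08 00:00:00`, not null. In short, as long as the value is a valid date-time, it should be parsed, even if its layout does not exactly match the pattern.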
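The mismatch can be reproduced outside Spark with `java.time`, which Spark 3's default timestamp parser builds on. The sketch assumes a hypothetical input `2020-08-08` and the pattern `yyyy-MM-dd HH:mm:ss`, since the original example input is not shown. Strict parsing fails because the time fields are missing. Parsing with a pattern that matches the data, and defaulting the time to midnight, gives the expected value. The object name is hypothetical.

```scala
import java.time.{LocalDate, LocalDateTime}
import java.time.format.DateTimeFormatter
import scala.util.Try

object PatternMismatchDemo {
  private val target = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")

  // Strict parse: the input must supply every field the pattern names.
  def strictParse(input: String): Try[String] =
    Try(LocalDateTime.parse(input, target).format(target))

  // Parse with a pattern that matches the data, then render in the requested pattern.
  def parseDateOnly(input: String): String =
    LocalDate.parse(input, DateTimeFormatter.ofPattern("yyyy-MM-dd"))
      .atStartOfDay()
      .format(target)
}
```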
That was our requirement. The same approach applies whenever a built-in function almost, but not quite, does what you need.
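Approach:
The usual routine applies. Trace the source code to find where the null is returned and where the error is thrown. Then write a function with rewritten logic and use it to override the original.
The code in question:
1. `ToTimestamp.eval`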
```scala
case StringType =>
  val fmt = right.eval(input)
  if (fmt == null) {
    null
  } else {
    val formatter = formatterOption.getOrElse(getFormatter(fmt.toString))
    try {
      formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor
    } catch {
      case e if isParseError(e) =>
        if (failOnError) {
          throw e
        } else {
          null
        }
    }
  }
```
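When parsing fails, the exception is caught. `failOnError` (true in strict/ANSI mode) then decides whether the exception is rethrown or `null` is returned.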
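The control flow of that branch comes down to one rule: parse, and on a parse error either rethrow or return null. Below is a plain-Scala sketch of the same decision. It assumes `java.time.LocalDateTime` as the parsed result and uses `Option` to stand in for SQL `NULL`. The names are hypothetical.

```scala
import java.time.{DateTimeException, LocalDateTime}
import java.time.format.DateTimeFormatter

object EvalModel {
  // Mirrors the StringType branch: a null pattern yields NULL, and a parse error
  // either propagates (failOnError, i.e. ANSI mode) or becomes NULL.
  def toTimestamp(input: String, pattern: String, failOnError: Boolean): Option[LocalDateTime] =
    if (pattern == null) None
    else {
      val formatter = DateTimeFormatter.ofPattern(pattern)
      try Some(LocalDateTime.parse(input, formatter))
      catch {
        // DateTimeParseException is a subclass of DateTimeException.
        case e: DateTimeException => if (failOnError) throw e else None
      }
    }
}
```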
2. `ToTimestamp.doGenCode`
```scala
case StringType => formatterOption.map {
  fmt =>
    val df = classOf[TimestampFormatter].getName
    val formatterName = ctx.addReferenceObj("formatter", fmt, df)
    nullSafeCodeGen(ctx, ev, (datetimeStr, _) =>
      s"""
         |try {
         |  ${ev.value} = $formatterName.parse($datetimeStr.toString()) / $downScaleFactor;
         |} catch (java.time.DateTimeException e) {
         |  $parseErrorBranch
         |} catch (java.time.format.DateTimeParseException e) {
         |  $parseErrorBranch
         |} catch (java.text.ParseException e) {
         |  $parseErrorBranch
         |}
         |""".stripMargin)
}
```
This method assembles Java source code as strings. Its logic is the same as in `eval`.
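`doGenCode` returns Java source text that Spark compiles at runtime with Janino, so the "logic" here is string templating. The stand-alone sketch below shows how a branch such as `parseErrorBranch` is chosen and spliced into a `try`/`catch` template. The names `value_0`, `isNull_0` and so on are placeholders for the fresh variable names Spark's `CodegenContext` would generate. The object and method names are hypothetical.

```scala
object CodegenTemplateDemo {
  // Chooses what the generated Java does on a parse error, as ToTimestamp does.
  def parseErrorBranch(failOnError: Boolean, isNull: String): String =
    if (failOnError) "throw e;" else s"$isNull = true;"

  // Builds a Java snippet; stripMargin removes leading whitespace up to and including '|'.
  def parseSnippet(value: String, isNull: String, formatter: String, input: String,
                   failOnError: Boolean): String = {
    val onError = parseErrorBranch(failOnError, isNull)
    s"""
       |try {
       |  $value = $formatter.parse($input.toString());
       |} catch (java.time.format.DateTimeParseException e) {
       |  $onError
       |} catch (java.time.DateTimeException e) {
       |  $onError
       |}
       |""".stripMargin
  }
}
```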
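Solution
1. Implementation
Create a case class that extends `ToTimestamp` and overrides the logic above.
Idea: when an exception is caught because the value's format does not match the pattern, infer the value's own format. Parse the value as a date-time with that format, then render the date-time as a string in the user-specified pattern and parse that string. See the code for details.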
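The core idea is to infer a pattern from the data by mapping each run of digits to the next date-time field, and it fits in a few lines. The simplified sketch below only illustrates the technique. It uses hypothetical names, assumes year-first input whose separators are not ASCII letters, and uses `SimpleDateFormat` like the code that follows.

```scala
import java.text.SimpleDateFormat

object InferPatternSketch {
  // Field letters, assigned in order to successive runs of digits.
  private val fields = "yMdHmsS"

  // "2020-8-8" -> "yyyy-M-d": each digit becomes the current field's letter,
  // and a non-digit that ends a digit run advances to the next field.
  def inferPattern(s: String): String = {
    val sb = new StringBuilder
    var field = 0
    for (i <- s.indices) {
      val c = s.charAt(i)
      if (c.isDigit) {
        if (field < fields.length) sb.append(fields.charAt(field))
      } else {
        if (i > 0 && s.charAt(i - 1).isDigit) field += 1
        sb.append(c)
      }
    }
    sb.toString
  }

  // Parse with the inferred pattern, then render in the pattern the caller asked for.
  def reformat(s: String, targetPattern: String): String =
    new SimpleDateFormat(targetPattern).format(new SimpleDateFormat(inferPattern(s)).parse(s))
}
```

With this, `reformat("2020-8-8", "yyyy-MM-dd HH:mm:ss")` produces the `2020-08-08 00:00:00` result described earlier.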
```scala
package v2.jdbc.spark.expressions.function

import java.text.ParseException
import java.time.format.DateTimeParseException
import java.time.{DateTimeException, ZoneId}

import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, TimeZoneAwareExpression, ToTimestamp}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMicros
import org.apache.spark.sql.catalyst.util.{LegacyDateFormats, TimestampFormatter}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DateType, StringType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
import v2.jdbc.spark.expressions.extra.{ExpressionUtils, FunctionDescription}
import v2.jdbc.spark.expressions.function.DateTimeUtils.dateStrChangeFormat

// Replacement for to_timestamp: on a parse error, infer the input's own format and retry.
case class BiGetTimestamp(
    left: Expression,
    right: Expression,
    timeZoneId: Option[String] = None,
    failOnError: Boolean = SQLConf.get.ansiEnabled) extends ToTimestamp {

  override val downScaleFactor = 1
  override def dataType: DataType = TimestampType

  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  // Broad on purpose: any RuntimeException (e.g. from the fallback) counts as a parse error.
  private def isParseError(e: Throwable): Boolean = e match {
    case _: DateTimeParseException | _: DateTimeException | _: RuntimeException | _: ParseException => true
    case _ => false
  }

  override def eval(input: InternalRow): Any = {
    val t = left.eval(input)
    if (t == null) {
      null
    } else {
      left.dataType match {
        case DateType =>
          daysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor
        case TimestampType =>
          t.asInstanceOf[Long] / downScaleFactor
        case StringType =>
          val fmt = right.eval(input)
          if (fmt == null) {
            null
          } else {
            val formatter = formatterOption.getOrElse(getFormatter(fmt.toString))
            try {
              formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor
            } catch {
              case e if isParseError(e) =>
                // Fallback: re-render the value in the requested pattern and parse again.
                // The fallback can fail as well, so it gets the same throw-or-null treatment.
                try {
                  formatter.parse(dateStrChangeFormat(t.toString, fmt.toString)) / downScaleFactor
                } catch {
                  case fallbackError if isParseError(fallbackError) =>
                    if (failOnError) throw e else null
                }
              case other =>
                if (failOnError) throw other else null
            }
          }
      }
    }
  }

  // Java counterpart of DateTimeUtils.identifyDateType + dateStrChangeFormat, emitted inside a
  // catch block. It has its own try/catch so that a failing fallback (including the checked
  // java.text.ParseException from SimpleDateFormat.parse) ends in parseErrorBranch.
  // "月" in the generated code is the Chinese character for "month".
  def doGenCodeErrorProcess(patternDecl: String, datetimeStr: String, ev: ExprCode,
      formatterName: String, targetPattern: String, parseErrorBranch: String): String = {
    s"""
       |try {
       |  boolean year = false;
       |  $patternDecl
       |  if (pattern.matcher($datetimeStr.toString().substring(0, 4)).matches()) {
       |    year = true;
       |  }
       |  StringBuilder sb = new StringBuilder();
       |  int index = 0;
       |  if (!year) {
       |    if ($datetimeStr.toString().contains("月") || $datetimeStr.toString().contains("-") || $datetimeStr.toString().contains("/")) {
       |      if (Character.isDigit($datetimeStr.toString().charAt(0))) {
       |        index = 1;
       |      }
       |    } else {
       |      index = 3;
       |    }
       |  }
       |  for (int i = 0; i < $datetimeStr.toString().length(); i++) {
       |    char chr = $datetimeStr.toString().charAt(i);
       |    if (Character.isDigit(chr)) {
       |      if (index <= 6) {
       |        sb.append("yMdHmsS".charAt(index));
       |      }
       |    } else {
       |      if (i > 0 && Character.isDigit($datetimeStr.toString().charAt(i - 1))) {
       |        index++;
       |      }
       |      sb.append(chr);
       |    }
       |  }
       |  java.text.SimpleDateFormat simpleDateFormat = new java.text.SimpleDateFormat(sb.toString());
       |  java.util.Date date = simpleDateFormat.parse($datetimeStr.toString());
       |  java.text.SimpleDateFormat simpleDateFormat2 = new java.text.SimpleDateFormat("$targetPattern");
       |  ${ev.value} = $formatterName.parse(simpleDateFormat2.format(date)) / $downScaleFactor;
       |} catch (java.lang.Exception fallbackError) {
       |  $parseErrorBranch
       |}
       |""".stripMargin
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)
    val parseErrorBranch = if (failOnError) "throw e;" else s"${ev.isNull} = true;"
    val code = left.dataType match {
      case StringType => formatterOption.map { fmt =>
        val df = classOf[TimestampFormatter].getName
        val formatterName = ctx.addReferenceObj("formatter", fmt, df)
        // "pattern" is an internal field of Spark's formatter implementation (not a public API)
        // and may change between Spark versions; it is read via reflection.
        val patternField = fmt.getClass.getDeclaredField("pattern")
        patternField.setAccessible(true)
        val pattern = patternField.get(fmt).toString
        val patternDecl =
          "java.util.regex.Pattern pattern = java.util.regex.Pattern.compile(\"^[-\\\\+]?[\\\\d]*$\");"
        nullSafeCodeGen(ctx, ev, (datetimeStr, _) => {
          val fallback =
            doGenCodeErrorProcess(patternDecl, datetimeStr, ev, formatterName, pattern, parseErrorBranch)
          // The subclass (DateTimeParseException) is caught before its superclass (DateTimeException).
          s"""
             |try {
             |  ${ev.value} = $formatterName.parse($datetimeStr.toString()) / $downScaleFactor;
             |} catch (java.time.format.DateTimeParseException e) {
             |  $fallback
             |} catch (java.time.DateTimeException e) {
             |  $fallback
             |} catch (java.text.ParseException e) {
             |  $fallback
             |} catch (java.lang.RuntimeException e) {
             |  $fallback
             |} catch (java.lang.Exception e) {
             |  $parseErrorBranch
             |}
             |""".stripMargin
        })
      }.getOrElse {
        // Non-constant pattern: no fallback on this path.
        val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
        val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
        val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$")
        val timestampFormatter = ctx.freshName("timestampFormatter")
        nullSafeCodeGen(ctx, ev, (string, format) =>
          s"""
             |$tf $timestampFormatter = $tf$$.MODULE$$.apply(
             |  $format.toString(),
             |  $zid,
             |  $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(),
             |  true);
             |try {
             |  ${ev.value} = $timestampFormatter.parse($string.toString()) / $downScaleFactor;
             |} catch (java.time.format.DateTimeParseException e) {
             |  $parseErrorBranch
             |} catch (java.time.DateTimeException e) {
             |  $parseErrorBranch
             |} catch (java.text.ParseException e) {
             |  $parseErrorBranch
             |}
             |""".stripMargin)
      }
      case TimestampType =>
        val eval1 = left.genCode(ctx)
        ev.copy(code =
          code"""
            ${eval1.code}
            boolean ${ev.isNull} = ${eval1.isNull};
            $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
            if (!${ev.isNull}) {
              ${ev.value} = ${eval1.value} / $downScaleFactor;
            }""")
      case DateType =>
        val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
        // Fully qualified: unqualified, "DateTimeUtils" resolves to this package's own object,
        // which has no daysToMicros, so the generated Java would not compile.
        val dtu = org.apache.spark.sql.catalyst.util.DateTimeUtils.getClass.getName.stripSuffix("$")
        val eval1 = left.genCode(ctx)
        ev.copy(code =
          code"""
            ${eval1.code}
            boolean ${ev.isNull} = ${eval1.isNull};
            $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
            if (!${ev.isNull}) {
              ${ev.value} = $dtu.daysToMicros(${eval1.value}, $zid) / $downScaleFactor;
            }""")
    }
    code
  }
}

object BiGetTimestamp {
  // Overrides the built-in to_timestamp; the time zone is hard-coded to Asia/Shanghai.
  val fd: FunctionDescription = (
    new FunctionIdentifier("to_timestamp"),
    ExpressionUtils.getExpressionInfo(classOf[BiGetTimestamp], "to_timestamp"),
    (children: Seq[Expression]) =>
      children.size match {
        case 1 => Cast(children.head, TimestampType, Some("Asia/Shanghai"))
        case 2 => BiGetTimestamp(children.head, children(1), Some("Asia/Shanghai"))
        case n => throw new IllegalArgumentException(s"to_timestamp expects 1 or 2 arguments, got $n")
      }
  )

  // Overrides the built-in to_date by casting the BiGetTimestamp result to DateType.
  val fd_toDte: FunctionDescription = (
    new FunctionIdentifier("to_date"),
    ExpressionUtils.getExpressionInfo(classOf[BiGetTimestamp], "to_date"),
    (children: Seq[Expression]) =>
      children.size match {
        case 1 => Cast(children.head, DateType, Some("Asia/Shanghai"))
        case 2 =>
          Cast(BiGetTimestamp(children.head, children(1), Some("Asia/Shanghai")), DateType, Some("Asia/Shanghai"))
        case n => throw new IllegalArgumentException(s"to_date expects 1 or 2 arguments, got $n")
      }
  )
}
```
```scala
package v2.jdbc.spark.expressions.function

import java.text.SimpleDateFormat
import java.util.regex.Pattern

object DateTimeUtils {
  /**
   * Infers a SimpleDateFormat pattern from a date string: each run of digits is mapped to the
   * next field (y, M, d, H, m, s, S) and separators are kept literally. Assumes year-first order.
   */
  def identifyDateType(str: String): String = {
    var year = false
    val pattern = Pattern.compile("^[-\\+]?[\\d]*$")
    // Throws StringIndexOutOfBoundsException for strings shorter than 4 characters.
    if (pattern.matcher(str.substring(0, 4)).matches) year = true
    val sb = new StringBuilder
    var index = 0
    if (!year) {
      // No leading year: "月" (Chinese for "month"), "-" or "/" suggest a month-day value;
      // otherwise assume the string starts with the hour. Braces match the generated Java.
      if (str.contains("月") || str.contains("-") || str.contains("/")) {
        if (Character.isDigit(str.charAt(0))) index = 1
      } else {
        index = 3
      }
    }
    for (i <- 0 until str.length) {
      val chr = str.charAt(i)
      if (Character.isDigit(chr)) {
        if (index <= 6) sb.append("yMdHmsS".charAt(index))
      } else {
        // A separator right after a digit ends the current field.
        if (i > 0 && Character.isDigit(str.charAt(i - 1))) index += 1
        sb.append(chr)
      }
    }
    sb.toString
  }

  /** Parses dateStr with its inferred format and re-renders it in targetFormat. */
  def dateStrChangeFormat(dateStr: String, targetFormat: String): String = {
    val sourceFormat = new SimpleDateFormat(identifyDateType(dateStr))
    val date = sourceFormat.parse(dateStr)
    new SimpleDateFormat(targetFormat).format(date)
  }
}
```
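2. Registering the function
There are two ways to register the function:
1. Pass the extension directly via `withExtensions` when building the SparkSession.
2. Leave it out of the code and configure it through SparkConf instead, by setting:
```
spark.sql.extensions=v2.jdbc.spark.expressions.extra.FunctionExtensions
```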
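The `FunctionExtensions` class named in that setting is not shown. The following is a hypothetical sketch of what such a class can look like, assuming Spark 3.x, where `SparkSessionExtensions.injectFunction` takes a `(FunctionIdentifier, ExpressionInfo, builder)` tuple. To stay self-contained it registers an illustrative `example_upper` function built on Spark's `Upper` expression. In the article's project, this is presumably where `BiGetTimestamp.fd` and `BiGetTimestamp.fd_toDte` would be injected instead. A class used through `spark.sql.extensions` needs a no-argument constructor.

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Upper}

// Hypothetical extension class (no-arg constructor, usable from spark.sql.extensions).
class ExampleFunctionExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Presumably the article's class calls extensions.injectFunction(BiGetTimestamp.fd) here.
    extensions.injectFunction((
      FunctionIdentifier("example_upper"),
      new ExpressionInfo(classOf[Upper].getName, "example_upper"),
      (children: Seq[Expression]) => Upper(children.head)))
  }
}

object ExampleFunctionExtensionsUsage {
  // Option 1: register programmatically while building the session.
  def session(): SparkSession =
    SparkSession.builder()
      .master("local[*]")
      .withExtensions(new ExampleFunctionExtensions)
      .getOrCreate()
}
```

Option 2 needs no code: `spark.sql.extensions` is set to the fully qualified class name, and Spark instantiates the class when the session is created.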
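Here we rewrote and overrode an existing function to meet our needs. Other cases follow the same decision:
- If a Spark SQL function does not fully meet your requirements and a UDF can close the gap, use a UDF.
- If a UDF cannot close the gap, override the function as in this example.
- If no suitable function exists at all, use the same approach to define a brand-new function. Extend the `Expression` subclass that matches the kind of function you need, and implement `eval` and `doGenCode`.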