SparkSql之用户自定义函数

为什么要自定义函数

虽然官方提供的sql函数已经很多,并且很强大了,但是有时候并不是都能满足我们的业务需求。除此之外,编写自定义函数能够让我们更加了解官方给定函数的底层实现。

函数的分类

sql函数一共分为三类

  • UDF[一条数据,一个结果]
    1)UDF:一行进入,一行出

  • UDAF[多条数据,一个结果,聚合函数]
    1)UDAF:输入多行,返回一行。
    2)Spark3.x推荐使用extends Aggregator自定义UDAF,属于强类型的Dataset方式。
    3)Spark2.x使用extends UserDefinedAggregateFunction,属于弱类型的DataFrame

  • UDTF[expload (spark不支持)]
    输入一行,返回多行(Hive);
    SparkSQL中没有UDTF,Spark中用flatMap即可实现该功能

如何自定义函数

步骤
1.定义一个函数
2.注册:sparkSession.udf.register("函数名称",对应的函数)
3.使用:在sql中进行使用

自定义UDF函数

需求:字符填充,长度由用户自定,填充字符由用户自定
如:customFill("tom","*",8) ;结果 *****tom

  1. 创建SparkSession
val sparkSession=SparkSession.builder().master("local[4]").appName("test").getOrCreate()
  1. 准备测试数据
    // 数据准备
    val list=List(
      Student(2,"绣花",16,"女",1),
      Student(5,"翠花",19,"女",2),
      Student(9,"王菲菲",20,"女",1),
      Student(11,"小惠",23,"女",1),
      Student(12,"梦雅",25,"女",3)
    )
// 为了方便,定义了一个样例类
case class Student(id:Int,name:String,age:Int,sex:String,classId:Int)
  1. 将数据注册成表
    // 导入隐式转换
    import sparkSession.implicits._

    // 转成 DataFrame
    val frame: DataFrame = list.toDF()

    // 注册成表
    frame.createOrReplaceTempView("student")
  1. 自定义函数
  /**
   * 自定义sql函数
   * @param coll 类名
   * @param symbol 符号
   * @param length 长度
   */
  def customFill(coll:String,symbol:String,length:Int): String ={

    if(coll.length>=length) coll
    else {
      symbol*(length-coll.length)+coll
    }
  }

5.注册自定义函数
name:第一个参数,给函数指定一个名称
func:将自定义函数传进去,
注意:以 def 声明的称为了方法,方法转函数(其实都是一个意思),需要在后面接上_

    //注册函数
    sparkSession.udf.register("customFill",customFill _)
  1. 编写sql,调用自定义函数并执行
// 编写sql,
    val frame1: DataFrame = sparkSession.sql(
      """
        |select id,customFill(name,'*',8) as name,age,sex,classId from  student
        |""".stripMargin)

    // 执行
    frame1.show()
  1. 运行结果
+---+-----------+---+---+-------+
| id|       name|age|sex|classId|
+---+-----------+---+---+-------+
|  2| ******绣花| 16| 女|      1|
|  5| ******翠花| 19| 女|      2|
|  9|*****王菲菲| 20| 女|      1|
| 11| ******小惠| 23| 女|      1|
| 12| ******梦雅| 25| 女|      3|
+---+-----------+---+---+-------+
  1. 完整代码
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.junit.Test

class SparkFunction {

  val sparkSession=SparkSession.builder().master("local[4]").appName("test").getOrCreate()

  @Test
  def demo01: Unit ={

    // 数据准备
    val list=List(
      Student(2,"绣花",16,"女",1),
      Student(5,"翠花",19,"女",2),
      Student(9,"王菲菲",20,"女",1),
      Student(11,"小惠",23,"女",1),
      Student(12,"梦雅",25,"女",3)
    )

    // 导入隐式转换
    import sparkSession.implicits._

    // 转成 DataFrame
    val frame: DataFrame = list.toDF()

    // 注册成表
    frame.createOrReplaceTempView("student")

    //注册函数
    sparkSession.udf.register("customFill",customFill _)

    // 编写sql
    val frame1: DataFrame = sparkSession.sql(
      """
        |select id,customFill(name,'*',8) as name,age,sex,classId from  student
        |""".stripMargin)

    // 执行
    frame1.show()

  }

 case class Student(id:Int,name:String,age:Int,sex:String,classId:Int)

  /**
   * 自定义sql函数
   * @param coll 类名
   * @param symbol 符号
   * @param length 长度
   */
  def customFill(coll:String,symbol:String,length:Int): String ={

    if(coll.length>=length) coll
    else {
      symbol*(length-coll.length)+coll
    }
  }

}

自定义UDAF函数

使用弱类型实现UDAF函数
步骤:

  1. 创建一个类

  2. 继承UserDefinedAggregateFunction抽象类(spark3.x版本中已标志为过期)

  3. 实现里面的抽象方法。

    //指定输入列的参数类型;需要指定为StructType类型
    override def inputSchema: StructType = ???
    //指定中间变量的类型
    override def bufferSchema: StructType = ???
    //指定聚合函数的结果类型
    override def dataType: DataType = ???
    //一致性指定(是否以同样输入返回同样的结果)
    override def deterministic: Boolean = ???
    //初始化中间变量的值
    override def initialize(buffer: MutableAggregationBuffer): Unit = ???
    //累加 [在每个task中执行]
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
    //合并所有task中该分组的所有的数据
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
    //计算得到最终结果
    override def evaluate(buffer: Row): Any = ???

  4. 创建自定义UDAF对象

  5. 注册自定义函数

  6. 编写sql并使用

数据准备

    // 数据准备
    val list=List(
      Student(2,"绣花",16,"女",1),
      Student(5,"翠花",19,"女",2),
      Student(9,"王菲菲",20,"女",1),
      Student(11,"小惠",23,"女",1),
      Student(12,"梦雅",25,"女",3)
    )

需求
统计用户的平均年龄(总年龄/总人数)

自定义函数

/**
 * 使用弱类型定义UDAF函数
 */
class CustomUdafByWeak extends UserDefinedAggregateFunction{

  /**
   * 指定输入列的参数类型;需要指定为`StructType`类型
   * @return
   */
  override def inputSchema: StructType = {
    // input 随便指定
    // LongType 输入进来的是年龄,所以需要指定IntegerType或LongType类型,
    val fields=Array(StructField("input",LongType))
    StructType(fields)
  }

  /**
   * 指定中间变量的类型
   * @return
   */
  override def bufferSchema: StructType = {
    // 当接收到输入的年龄时,肯定需要存起来,记录年龄总和(sum),次数(count)等,方便最终求平均年龄

    val fields=Array(
      //定义,记录总年龄
      StructField("sum",LongType),
      //定义,记录次数
      StructField("count",LongType)
    )
    StructType(fields)
  }

  /**
   * 指定聚合函数的结果类型
   * @return
   */
  override def dataType: DataType = DoubleType

  /**
   * 一致性指定
   * @return
   */
  override def deterministic: Boolean = true

  /**
   * 初始化中间变量的值
   * @param buffer
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // buffer中存放在中间变量数据

    // 在 bufferSchema 中定义了中间变量的类型,此时需要对中间变量进行设置
    // 默认的话,总年龄应该为0,总次数也应该为0

    //如何获取?sum 和 count ?
    // buffer.getAs[类型](根据角标取值)

    //如何设置值呢?
    // buffer(角标)= value
    // buffer.update(角标,value)

    buffer(0)= 0L // 总年龄
    buffer(1)= 0L // 总次数

  }

  /**
   * 累加 [在每个task中执行]
   * @param buffer
   * @param input
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // buffer中存放在中间变量数据
    // input 当前输入的年龄

    // 获取 上一次sum
    val preSum:Long = buffer.getAs[Long](0)

    // 获取 上一次count
    val preCount: Long = buffer.getAs[Long](1)

    // 从input中取出年龄 在 inputSchema函数中只指定了一个参数,所以用角标0取值即可。
    val age=input.getAs[Long](0)


    //重新修改值
    buffer.update(0,preSum+age)
    buffer.update(1,preCount+1)

  }

  /**
   * 合并所有task中该分组的所有的sum与count
   * @param buffer1
   * @param buffer2
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    // 同样的操作,取值赋值

    // 获取 上一次sum
    val preSum:Long = buffer1.getAs[Long](0)

    // 获取 上一次count
    val preCount: Long = buffer1.getAs[Long](1)

    // 取各个分区的 sum 和 count
    val partitionSum:Long = buffer2.getAs[Long](0)
    val partitionCount:Long = buffer2.getAs[Long](1)

    // 累加保存
    buffer1.update(0,preSum+partitionSum)
    buffer1.update(1,preCount+partitionCount)
  }

  /**
   * 计算得到最终结果
   * @param buffer
   * @return
   */
  override def evaluate(buffer: Row): Any = {

    // 计算平均年龄
    // sum / count =avg

    // 总年龄
    val sum=buffer.getAs[Long](0)
    // 总次数
    val count=buffer.getAs[Long](1)

    sum.toDouble/count
  }
}

测试

  @Test
  def demo02(): Unit ={

  val sparkSession=SparkSession.builder().master("local[4]").appName("test").getOrCreate()

  // 数据准备
  val list=List(
    Student(2,"绣花",16,"女",1),
    Student(5,"翠花",19,"女",2),
    Student(9,"王菲菲",20,"女",1),
    Student(11,"小惠",23,"女",1),
    Student(12,"梦雅",25,"女",3)
  )

    // 导入隐式转换
    import sparkSession.implicits._

    // 注册成表
    val df: DataFrame = list.toDF("id","name","age","sex","class_id")
    df.createOrReplaceTempView("student")

    // 创建自定义UDAF对象
    val fun=new CustomUdafByWeak

    // 注册
    sparkSession.udf.register("custom_avg",fun)

    // 编写sql
    val df2: DataFrame = sparkSession.sql(
      """
        |select custom_avg(age) from student
        |""".stripMargin)

    df2.show()

  }

结果

+-------------------------------------+
|customudafbyweak(CAST(age AS BIGINT))|
+-------------------------------------+
|                                 20.6|
+-------------------------------------+

使用强类型实现UDAF函数

@Stable
@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
  " via the functions.udaf(agg) method.", "3.0.0")
abstract class UserDefinedAggregateFunction extends Serializable {...}

在spark 3.x中UserDefinedAggregateFunction已经被弃用了,目前推荐的是使用Aggregator[IN, BUF, OUT]

import org.apache.spark.sql.expressions.Aggregator

它需要我们指定三个类型(参数语义和上面是一样的)
IN:输入类型
BUF:中间类型
OUT:最终输出类型

步骤:

  1. 定义class

  2. 继承Aggregator 指定 INBUFOUT 参数类型

  3. 重写内部方法

    // 初始化中间变量
    override def zero: ParamBuff = ???
    // 在每个分区中针对每个组进行合并
    override def reduce(b: ParamBuff, a: Long): ParamBuff = ???
    // 在新的RDD分区中针对每个组的所有父RDD分区结果进行合并
    override def merge(b1: ParamBuff, b2: ParamBuff): ParamBuff = ???
    // 最终结果计算
    override def finish(reduction: ParamBuff): Double = ???
    // 指定中间变量的编码方式
    override def bufferEncoder: Encoder[ParamBuff] = ???
    // 指定结果类型的编码方式
    override def outputEncoder: Encoder[Double] = ???

  4. 创建自定义UDAF类

val fun = new CustomUdafByStrong
  1. 导入import org.apache.spark.sql.functions._ 转换成 udaf
    // 转换成udaf
    import org.apache.spark.sql.functions._
    // 创建自定义UDAF对象
    val func = udaf(fun)
  1. 注册
parkSession.udf.register("custom_avg",func)
  1. 调用

自定义UDAF函数

/**
 * 中间变量需要两个参数,
 * @param sum 计算年龄总数
 * @param count // 计算年龄个数
 */
case class ParamBuff(sum:Long,count:Long)
/**
 * 使用强类型定义UDAF函数
 */
class CustomUdafByStrong extends Aggregator[Long,ParamBuff,Double]{
  /**
   * 初始化中间变量
   * @return
   */
  override def zero: ParamBuff = {
    ParamBuff(0L,0L)
  }

  /**
   * 在每个分区中针对每个组进行合并
   * @param b  ParamBuff 样例类
   * @param a  a 传入进来的年龄值
   * @return
   */
  override def reduce(b: ParamBuff, a: Long): ParamBuff = {
    // 获取总年龄
    ParamBuff(b.sum+a,b.count+1)
  }

  /**
   * 在新的RDD分区中针对每个组的所有父RDD分区结果进行合并
   * @param b1
   * @param b2
   * @return
   */
  override def merge(b1: ParamBuff, b2: ParamBuff): ParamBuff = {
    ParamBuff(b1.sum+b2.sum,b1.count+b2.count)
  }

  /**
   * 最终结果计算
   * @param reduction
   * @return
   */
  override def finish(reduction: ParamBuff): Double = {
    reduction.sum.toDouble/reduction.count
  }

  /**
   * 指定中间变量的编码方式
   * @return
   */
  override def bufferEncoder: Encoder[ParamBuff] = Encoders.product[ParamBuff]

  /**
   * 指定结果类型的编码方式
   * @return
   */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

测试

  @Test
  def demo03(): Unit ={


    // 导入隐式转换
    import sparkSession.implicits._

    // 注册成表
    val df: DataFrame = list.toDF("id","name","age","sex","class_id")
    df.createOrReplaceTempView("student")


    // 转换成udaf
    import org.apache.spark.sql.functions._
    // 创建自定义UDAF对象
    val func = udaf(new CustomUdafByStrong)

    // 注册
    sparkSession.udf.register("custom_avg",func)

    // 编写sql
    val df2: DataFrame = sparkSession.sql(
      """
        |select custom_avg(age) from student
        |""".stripMargin)

    df2.show()

  }

结果

+---------------------------------------+
|customudafbystrong(CAST(age AS BIGINT))|
+---------------------------------------+
|                                   20.6|
+---------------------------------------+

你可能感兴趣的:(SparkSql之用户自定义函数)