SparkSQL 用户自定义函数

SparkSQL允许用户可以通过spark.udf功能添加自定义函数,实现自定义功能。

一、UDF

1. 创建DataFrame

scala> val df = spark.read.json("data/user.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, username: string]

2. 注册UDF

scala> spark.udf.register("addName",(x:String)=> "Name:"+x)
res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

3.创建临时表

scala> df.createOrReplaceTempView("people")

4. 应用UDF

scala> spark.sql("Select addName(name),age from people").show()

二、UDAF

  强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。通过继承UserDefinedAggregateFunction来实现用户自定义弱类型聚合函数。
  从Spark3.0版本后,UserDefinedAggregateFunction已经不推荐使用了。可以统一采用强类型聚合函数Aggregator

示例:

需求:计算平均工资

1. UDAF - 弱类型

不再推荐使用UserDefinedAggregateFunction

/*
定义类继承UserDefinedAggregateFunction,并重写其中方法
*/
class MyAveragUDAF extends UserDefinedAggregateFunction {

  // 聚合函数输入参数的数据类型
  def inputSchema: StructType = StructType(Array(StructField("age",IntegerType)))

  // 聚合函数缓冲区中值的数据类型(age,count)
  def bufferSchema: StructType = {
    StructType(Array(StructField("sum",LongType),StructField("count",LongType)))
  }

  // 函数返回值的数据类型
  def dataType: DataType = DoubleType

  // 稳定性:对于相同的输入是否一直返回相同的输出。
  def deterministic: Boolean = true

  // 函数缓冲区初始化
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 存年龄的总和
    buffer(0) = 0L
    // 存年龄的个数
    buffer(1) = 0L
  }

  // 更新缓冲区中的数据
  def update(buffer: MutableAggregationBuffer,input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // 合并缓冲区
  def merge(buffer1: MutableAggregationBuffer,buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // 计算最终结果
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

。。。

//创建聚合函数
var myAverage = new MyAveragUDAF

//在spark中注册聚合函数
spark.udf.register("avgAge",myAverage)

spark.sql("select avgAge(age) from user").show()

推荐使用Aggregator

// TODO 创建UDAF函数
val udaf = new MyAvgAgeUDAF
// TODO 注册到 SparkSQL中
spark.udf.register("avgAge", functions.udaf(udaf))
// TODO 在SQL中使用聚合函数

// 定义用户的自定义聚合函数
spark.sql("select avgAge(age) from user").show

// **************************************************

case class Buff( var sum:Long, var cnt:Long )
// totalage, count
class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double]{
    override def zero: Buff = Buff(0,0)

    override def reduce(b: Buff, a: Long): Buff = {
        b.sum += a
        b.cnt += 1
        b
    }

    override def merge(b1: Buff, b2: Buff): Buff = {
        b1.sum += b2.sum
        b1.cnt += b2.cnt
        b1
    }

    override def finish(reduction: Buff): Double = {
        reduction.sum.toDouble/reduction.cnt
    }

    override def bufferEncoder: Encoder[Buff] = Encoders.product

    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

2. UDAF - 强类型

//输入数据类型
case class User01(username:String,age:Long)
//缓存类型
case class AgeBuffer(var sum:Long,var count:Long)

/**
  * 定义类继承org.apache.spark.sql.expressions.Aggregator
  * 重写类中的方法
  */
class MyAveragUDAF1 extends Aggregator[User01,AgeBuffer,Double]{
  override def zero: AgeBuffer = {
    AgeBuffer(0L,0L)
  }

  override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  override def finish(buff: AgeBuffer): Double = {
    buff.sum.toDouble/buff.count
  }
  //DataSet默认额编解码器,用于序列化,固定写法
  //自定义类型就是product自带类型根据类型选择
  override def bufferEncoder: Encoder[AgeBuffer] = {
    Encoders.product
  }

  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}

。。。

//封装为DataSet
val ds: Dataset[User01] = df.as[User01]

//创建聚合函数
var myAgeUdaf1 = new MyAveragUDAF1
//将聚合函数转换为查询的列
val col: TypedColumn[User01, Double] = myAgeUdaf1.toColumn

//查询
ds.select(col).show()

需要将聚合类转为列才可以在DataSet中使用,使用方式稍稍有些特殊

你可能感兴趣的:(大数据,spark,scala,big,data,大数据)