How to write a custom aggregate function (UDAF) in Spark

To define a custom aggregate function in Spark, extend the abstract class UserDefinedAggregateFunction and override the following properties and methods:

1. inputSchema: the function's parameter list, expressed as a StructType. For example:

override def inputSchema: StructType = StructType(Array(StructField("age", IntegerType)))

Each parameter is one StructField: "age" is the parameter name and IntegerType is its type (an int). You cannot use Scala's Int here; it must be a Spark SQL data type. The supported types are the classes in the org.apache.spark.sql.types package.
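If the function takes more than one argument, each argument becomes its own StructField. As a hypothetical sketch (the names "value" and "weight" are illustrative only, e.g. for a weighted-average UDAF), the schema could look like:

```scala
// Hypothetical two-parameter schema: a value column and a weight column.
// Callers pass arguments positionally, so the field names mainly serve
// as documentation. Assumes DoubleType is imported from
// org.apache.spark.sql.types.
override def inputSchema: StructType = StructType(Array(
  StructField("value", DoubleType),
  StructField("weight", DoubleType)
))
```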

 

2. bufferSchema: the schema of the intermediate result. For example, when summing a, b, and c, you first compute a + b and store the result ab, then compute ab + c; that ab is the intermediate result. For an average, you store a running sum and a count, and those two together form the intermediate result. The format is the same as inputSchema:

override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType),
                                                         StructField("sum_age", IntegerType)))

3. dataType: the type of the return value, declared as a DataType, i.e. one of the classes in the org.apache.spark.sql.types package. Example:

override def dataType: DataType = IntegerType

4. deterministic: whether the function is deterministic, i.e. whether the same input is guaranteed to produce the same output. Example:

override def deterministic: Boolean = true

5. initialize: initializes the intermediate result. For example, a sum function must set the buffer to 0 before computation begins. Example:

  override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0 }

buffer holds the intermediate result; MutableAggregationBuffer is a subclass of Row.

6. update(buffer: MutableAggregationBuffer, input: Row): folds one input row into the intermediate result. input is a single row of the DataFrame; buffer is the intermediate result accumulated so far while traversing the partition. Example:

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + input.getInt(0)
  }

7. merge(buffer1: MutableAggregationBuffer, buffer2: Row): merges partition-level results. buffer2 is the intermediate result of one partition; buffer1 is the intermediate result of the overall merge. Example:

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
  }

8. evaluate(buffer: Row): returns the final result. buffer is buffer1 from step 7 after all partitions have been merged. Example:

override def evaluate(buffer: Row): Any = buffer.getInt(0)
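One caveat: if dataType is IntegerType and evaluate divides two Ints (as in the average example below), the result is truncated integer division. A sketch of an exact-average variant, assuming the same buffer layout (buffer(0) = count, buffer(1) = sum) and DoubleType imported from org.apache.spark.sql.types, could look like:

```scala
// Sketch: return an exact average as a Double instead of a truncated Int.
override def dataType: DataType = DoubleType

override def evaluate(buffer: Row): Any =
  buffer.getInt(1).toDouble / buffer.getInt(0)
```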

 

Below is a UDAF I wrote that computes an average:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}


class MyUDAF extends UserDefinedAggregateFunction {

  // Single input column: the age to average.
  override def inputSchema: StructType = StructType(Array(StructField("age", IntegerType)))

  // Buffer holds a row count and a running sum of ages.
  override def bufferSchema: StructType = StructType(Array(
    StructField("count", IntegerType),
    StructField("ages", IntegerType)))

  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  // Start with count = 0 and sum = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0
  }

  // Fold one input row into the buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1
    buffer(1) = buffer.getInt(1) + input.getInt(0)
  }

  // Combine two partition-level buffers.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // Note: integer division, so the average is truncated toward zero.
  override def evaluate(buffer: Row): Any = buffer.getInt(1) / buffer.getInt(0)

}
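As a side note, UserDefinedAggregateFunction was deprecated in Spark 3.0 in favor of the typed Aggregator API, registered through functions.udaf. A rough sketch of the same average under that API, shown only for comparison (names like AvgBuffer and MyAvg are my own):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Buffer case class replacing bufferSchema: a count plus a running sum.
case class AvgBuffer(count: Int, sum: Int)

object MyAvg extends Aggregator[Int, AvgBuffer, Int] {
  def zero: AvgBuffer = AvgBuffer(0, 0)                  // like initialize
  def reduce(b: AvgBuffer, a: Int): AvgBuffer =          // like update
    AvgBuffer(b.count + 1, b.sum + a)
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =   // like merge
    AvgBuffer(b1.count + b2.count, b1.sum + b2.sum)
  def finish(b: AvgBuffer): Int = b.sum / b.count        // like evaluate
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// Registration (Spark 3.0+):
// spark.udf.register("avgage", org.apache.spark.sql.functions.udaf(MyAvg))
```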

It is used as follows:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession
      .builder
      .master("local")
      .appName("avg")
      .getOrCreate()
    spark.udf.register("avgage", new MyUDAF)
    val ageDF = spark.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2)))
      .toDF("age", "class_id")
    // registerTempTable is deprecated; createOrReplaceTempView is the current API.
    ageDF.createOrReplaceTempView("bigDataTable")
    spark.sql("select avgage(age) from bigDataTable group by class_id").show()

Output:

+-----------+
|avgage(age)|
+-----------+
|         23|
|         13|
+-----------+
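The same aggregation can also be expressed through the DataFrame API instead of SQL. Assuming the registration above, something like the following should be equivalent (callUDF looks up the function registered under the name "avgage"):

```scala
import org.apache.spark.sql.functions.{callUDF, col}

// Equivalent to the SQL query above, via the DataFrame API.
ageDF.groupBy("class_id")
  .agg(callUDF("avgage", col("age")).as("avg_age"))
  .show()
```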
