A custom aggregate function (UDAF) in Spark is written by extending the abstract class UserDefinedAggregateFunction and overriding the following properties and methods:
1. inputSchema: the function's parameter list, written as a StructType, for example:
override def inputSchema: StructType = StructType(Array(StructField("age", IntegerType)))
Each parameter is one StructField: "age" is the parameter name and IntegerType is its type, i.e. int. Note that you cannot use Scala's Int here; it must be a Spark SQL data type. The supported types are the classes under the org.apache.spark.sql.types package. A function that takes several columns simply lists one StructField per column; see the sketch after this list.
2. bufferSchema: the schema of the intermediate result. For example, when summing a, b, and c, Spark first computes a+b and stores the result ab, then computes ab+c; that ab is the intermediate result. For an average, the running sum and the running count are stored, and together they form the intermediate result. The declaration looks just like the one in step 1:
override def bufferSchema: StructType = StructType(Array(
  StructField("count", IntegerType),
  StructField("sum_age", IntegerType)))
3. dataType: the type of the return value, declared as a DataType, again one of the classes under the org.apache.spark.sql.types package. Example:
override def dataType: DataType = IntegerType
4. deterministic: whether the result is deterministic, i.e. whether the same input is guaranteed to produce the same output. A function that mixed in a random number or the current timestamp would have to return false here. Example:
override def deterministic: Boolean = true
5. initialize: initializes the intermediate result. A sum function, for example, must set the intermediate result to 0 before the computation starts. Example:
override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0 }
buffer is the intermediate result; MutableAggregationBuffer is a subclass of Row.
6. update(buffer: MutableAggregationBuffer, input: Row): updates the intermediate result. input is one row of the DataFrame; buffer is the intermediate result accumulated so far while walking the current partition. Example:
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  buffer(0) = buffer.getInt(0) + input.getInt(0)
}
7. merge(buffer1: MutableAggregationBuffer, buffer2: Row): merges partitions. buffer2 is the intermediate result of one partition; buffer1 is the running intermediate result of the whole merge. Example:
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
}
8. evaluate(buffer: Row): returns the function's result. buffer is buffer1 from step 7 after every partition has been merged in. Example:
override def evaluate(buffer: Row): Any = buffer.getInt(0)
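As promised in step 1, here is a minimal sketch of a multi-column inputSchema, assuming a hypothetical weighted-average UDAF that reads a value column and a weight column (both field names are mine, not from the example above):
override def inputSchema: StructType = StructType(Array(
  StructField("value", DoubleType),    // the value to average (hypothetical)
  StructField("weight", DoubleType)))  // its weight (hypothetical)
A SQL call would then pass two columns, e.g. weighted_avg(score, weight).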
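To see how the eight pieces fit together, here is a rough trace of the sum example from steps 5 through 8, assuming the values 1 and 2 land in one partition and 3 in another:
// Partition 1: initialize               -> buffer  = (0)
//              update(buffer, Row(1))   -> buffer  = (1)
//              update(buffer, Row(2))   -> buffer  = (3)
// Partition 2: initialize               -> buffer  = (0)
//              update(buffer, Row(3))   -> buffer  = (3)
// Merge step:  merge(buffer1, buffer2)  -> buffer1 = (6)
// Final step:  evaluate(buffer1)        -> 6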
Below is an average UDAF I wrote:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

class MyUDAF extends UserDefinedAggregateFunction {
  // One input column: the age to average
  override def inputSchema: StructType = StructType(Array(StructField("age", IntegerType)))
  // The buffer holds the running count and the running sum of ages
  override def bufferSchema: StructType = StructType(Array(
    StructField("count", IntegerType),
    StructField("sum_age", IntegerType)))
  override def dataType: DataType = IntegerType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0 // count
    buffer(1) = 0 // sum of ages
  }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1
    buffer(1) = buffer.getInt(1) + input.getInt(0)
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }
  // Integer division: the average is truncated
  override def evaluate(buffer: Row): Any = buffer.getInt(1) / buffer.getInt(0)
}
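Because evaluate uses integer division, the average is truncated. If a fractional average is wanted, a small variant (my sketch, not part of the original class) is to return a Double instead, with DoubleType added to the types import:
override def dataType: DataType = DoubleType
override def evaluate(buffer: Row): Any = buffer.getInt(1).toDouble / buffer.getInt(0)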
Usage is as follows:
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder
  .master("local")
  .appName("avg")
  .getOrCreate()
import spark.implicits._

spark.udf.register("avgage", new MyUDAF)
val ageDF = spark.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2)))
  .toDF("age", "class_id")
ageDF.createOrReplaceTempView("bigDataTable")
spark.sql("select avgage(age) from bigDataTable group by class_id").show()
Output:
+-----------+
|myudaf(age)|
+-----------+
|         23|
|         13|
+-----------+
This matches (22 + 24) / 2 = 23 for class 1 and (11 + 15) / 2 = 13 for class 2 (truncating integer division).
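For completeness: since Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of org.apache.spark.sql.expressions.Aggregator, registered for SQL use through functions.udaf. Here is a minimal sketch of the same integer average in that style (the buffer class AvgBuffer and the object name MyAvg are my own):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

case class AvgBuffer(count: Int, sum: Int) // running count and running sum

object MyAvg extends Aggregator[Int, AvgBuffer, Int] {
  def zero: AvgBuffer = AvgBuffer(0, 0)                  // plays the role of initialize
  def reduce(b: AvgBuffer, a: Int): AvgBuffer =          // plays the role of update
    AvgBuffer(b.count + 1, b.sum + a)
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =   // plays the role of merge
    AvgBuffer(b1.count + b2.count, b1.sum + b2.sum)
  def finish(b: AvgBuffer): Int = b.sum / b.count        // plays the role of evaluate
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

spark.udf.register("avgage", udaf(MyAvg))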