SparkSQL允许用户可以通过spark.udf功能添加自定义函数,实现自定义功能。
scala> val df = spark.read.json("data/user.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, username: string]
scala> spark.udf.register("addName",(x:String)=> "Name:"+x)
res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
scala> df.createOrReplaceTempView("people")
scala> spark.sql("Select addName(name),age from people").show()
强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。通过继承UserDefinedAggregateFunction来实现用户自定义弱类型聚合函数。
从Spark3.0版本后,UserDefinedAggregateFunction已经不推荐使用了。可以统一采用强类型聚合函数Aggregator
示例:
需求:计算平均工资
不再推荐使用
UserDefinedAggregateFunction
/*
定义类继承UserDefinedAggregateFunction,并重写其中方法
*/
class MyAveragUDAF extends UserDefinedAggregateFunction {
// 聚合函数输入参数的数据类型
def inputSchema: StructType = StructType(Array(StructField("age",IntegerType)))
// 聚合函数缓冲区中值的数据类型(age,count)
def bufferSchema: StructType = {
StructType(Array(StructField("sum",LongType),StructField("count",LongType)))
}
// 函数返回值的数据类型
def dataType: DataType = DoubleType
// 稳定性:对于相同的输入是否一直返回相同的输出。
def deterministic: Boolean = true
// 函数缓冲区初始化
def initialize(buffer: MutableAggregationBuffer): Unit = {
// 存年龄的总和
buffer(0) = 0L
// 存年龄的个数
buffer(1) = 0L
}
// 更新缓冲区中的数据
def update(buffer: MutableAggregationBuffer,input: Row): Unit = {
if (!input.isNullAt(0)) {
buffer(0) = buffer.getLong(0) + input.getInt(0)
buffer(1) = buffer.getLong(1) + 1
}
}
// 合并缓冲区
def merge(buffer1: MutableAggregationBuffer,buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
。。。
//创建聚合函数
var myAverage = new MyAveragUDAF
//在spark中注册聚合函数
spark.udf.register("avgAge",myAverage)
spark.sql("select avgAge(age) from user").show()
推荐使用
Aggregator
// TODO 创建UDAF函数
val udaf = new MyAvgAgeUDAF
// TODO 注册到 SparkSQL中
spark.udf.register("avgAge", functions.udaf(udaf))
// TODO 在SQL中使用聚合函数
// 定义用户的自定义聚合函数
spark.sql("select avgAge(age) from user").show
// **************************************************
case class Buff( var sum:Long, var cnt:Long )
// totalage, count
class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double]{
override def zero: Buff = Buff(0,0)
override def reduce(b: Buff, a: Long): Buff = {
b.sum += a
b.cnt += 1
b
}
override def merge(b1: Buff, b2: Buff): Buff = {
b1.sum += b2.sum
b1.cnt += b2.cnt
b1
}
override def finish(reduction: Buff): Double = {
reduction.sum.toDouble/reduction.cnt
}
override def bufferEncoder: Encoder[Buff] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
//输入数据类型
case class User01(username:String,age:Long)
//缓存类型
case class AgeBuffer(var sum:Long,var count:Long)
/**
* 定义类继承org.apache.spark.sql.expressions.Aggregator
* 重写类中的方法
*/
class MyAveragUDAF1 extends Aggregator[User01,AgeBuffer,Double]{
override def zero: AgeBuffer = {
AgeBuffer(0L,0L)
}
override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
b.sum = b.sum + a.age
b.count = b.count + 1
b
}
override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
b1.sum = b1.sum + b2.sum
b1.count = b1.count + b2.count
b1
}
override def finish(buff: AgeBuffer): Double = {
buff.sum.toDouble/buff.count
}
//DataSet默认额编解码器,用于序列化,固定写法
//自定义类型就是product自带类型根据类型选择
override def bufferEncoder: Encoder[AgeBuffer] = {
Encoders.product
}
override def outputEncoder: Encoder[Double] = {
Encoders.scalaDouble
}
}
。。。
//封装为DataSet
val ds: Dataset[User01] = df.as[User01]
//创建聚合函数
var myAgeUdaf1 = new MyAveragUDAF1
//将聚合函数转换为查询的列
val col: TypedColumn[User01, Double] = myAgeUdaf1.toColumn
//查询
ds.select(col).show()
需要将聚合类转为列才可以在DataSet中使用,使用方式稍稍有些特殊