Contents
1 UDF
2 UDAF (weakly typed)
3 Aggregator (strongly typed)
4 Early strongly typed UDAF usage in Spark
1 UDF
Custom functions can be registered through spark.udf and then called from SQL to implement user-defined logic.
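The examples below read data/user.json. Based on the outputs shown later, its contents are presumably (one JSON object per line):
{"username": "小王", "age": 20}
{"username": "小李", "age": 30}
{"username": "小黑", "age": 40}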
val df: DataFrame = spark.read.json("data/user.json")
df.createOrReplaceTempView("user")
// Register custom SparkSQL functions
spark.udf.register("prefixName", (name: String) => {
  "Name:" + name
})
// JSON numbers are read as bigint, so the UDF takes a Long rather than an Int
spark.udf.register("tailSui", (age: Long) => {
  age + "岁"
})
spark.sql(
  """
    |select
    |  tailSui(age),
    |  prefixName(username)
    |from user
    """.stripMargin).show()
+------------+--------------------+
|tailSui(age)|prefixName(username)|
+------------+--------------------+
| 20岁| Name:小王|
| 30岁| Name:小李|
| 40岁| Name:小黑|
+------------+--------------------+
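Besides SQL registration, a UDF can also be used in DSL style through org.apache.spark.sql.functions.udf, with no temp view needed. A minimal sketch of the same logic, assuming the df from the example above:
import org.apache.spark.sql.functions.udf

val prefixName = udf((name: String) => "Name:" + name)
df.select(prefixName(df("username"))).show()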
2 UDAF (weakly typed)
Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions. A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. Since Spark 3.0, however, UserDefinedAggregateFunction is deprecated, and the strongly typed Aggregator is the recommended replacement.
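For comparison, the built-in aggregates can be called directly against the same user view, for example:
spark.sql("select avg(age), max(age), count(age) from user").show()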
Deprecated since Spark 3.0:
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

def main(args: Array[String]): Unit = {
  val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
  val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
  val df: DataFrame = spark.read.json("data/user.json")
  df.createTempView("user")
  spark.udf.register("MyAvg", new MyAggregateBuffer)
  spark.sql("select MyAvg(age) from user").show()
  spark.stop()
}
class MyAggregateBuffer extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = StructType(
    Array(
      StructField("age", LongType)
    )
  )
  // Schema of the aggregation buffer (intermediate state)
  override def bufferSchema: StructType = StructType(
    Array(
      StructField("total", LongType),
      StructField("count", LongType)
    )
  )
  // Result type of the function
  override def dataType: DataType = LongType
  // Whether the function always returns the same output for the same input
  override def deterministic: Boolean = true
  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0)) // running sum of ages in slot 0
    buffer.update(1, buffer.getLong(1) + 1)                // count of ages in slot 1
  }
  // Merge buffers from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }
  // Compute the average (integer division, so the result is truncated to a Long)
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1)
  }
}
+----------------------+
|myaggregatebuffer(age)|
+----------------------+
| 30|
+----------------------+
3 Aggregator (strongly typed)
The recommended approach since Spark 3.0:
Aggregator is a type-safe aggregator for performing aggregations over data. Its input, buffer, and output types are declared as generic parameters, and it exposes methods that fold the input into the buffer and produce the final output. The basic steps are shown below:
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

def main(args: Array[String]): Unit = {
  val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
  val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
  val df: DataFrame = spark.read.json("data/user.json")
  df.createTempView("user")
  // functions.udaf wraps the strongly typed Aggregator as an untyped UDF for SQL (Spark 3.0+)
  spark.udf.register("MyAvg", functions.udaf(new MyAvgUDAF))
  spark.sql("select MyAvg(age) from user").show()
  spark.stop()
}
// org.apache.spark.sql.expressions.Aggregator
// Input type:  Long
// Buffer type: Buf
// Output type: Double
case class Buf(var total: Long, var count: Long)

class MyAvgUDAF extends Aggregator[Long, Buf, Double] {
  // Initial value of the aggregation buffer
  override def zero: Buf = {
    Buf(0L, 0L)
  }
  // Fold one input value into the buffer
  override def reduce(b: Buf, in: Long): Buf = {
    b.total += in
    b.count += 1
    b
  }
  // Merge buffers from different partitions
  override def merge(b1: Buf, b2: Buf): Buf = {
    b1.total += b2.total
    b1.count += b2.count
    b1
  }
  // Compute the final result
  override def finish(reduction: Buf): Double = {
    reduction.total.toDouble / reduction.count
  }
  // Distributed execution requires serialization: use Encoders.product for custom case classes
  override def bufferEncoder: Encoder[Buf] = Encoders.product
  // ...and Encoders.scalaXxx (e.g. Encoders.scalaDouble) for Scala primitive types
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
+--------------+
|myavgudaf(age)|
+--------------+
| 30.0|
+--------------+
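The UserDefinedFunction returned by functions.udaf can also be applied in DSL style, with no temp view or SQL string needed. A minimal sketch against the same df:
val myAvg = functions.udaf(new MyAvgUDAF)
df.select(myAvg(df("age"))).show()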
4 Early strongly typed UDAF usage in Spark
Before functions.udaf was introduced in Spark 3.0, a strongly typed Aggregator could not be registered for SQL. Instead, it was converted into a query column with toColumn and used with DSL syntax on a strongly typed Dataset.

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

def main(args: Array[String]): Unit = {
  val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
  val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
  import spark.implicits._
  val df: DataFrame = spark.read.json("data/user.json")
  // Early strongly typed UDAFs are used with DSL syntax on a Dataset
  val ds: Dataset[User] = df.as[User]
  // Convert the Aggregator into a column that can appear in a query
  val udafCol: TypedColumn[User, Double] = new MyAvgUDAF().toColumn
  ds.select(udafCol).show()
  spark.stop()
}
case class User(var username: String, var age: Long)

// org.apache.spark.sql.expressions.Aggregator
// Input type:  User
// Buffer type: Buf
// Output type: Double
case class Buf(var total: Long, var count: Long)

class MyAvgUDAF extends Aggregator[User, Buf, Double] {
  // Initial value of the aggregation buffer
  override def zero: Buf = {
    Buf(0L, 0L)
  }
  // Fold one input row into the buffer
  override def reduce(b: Buf, in: User): Buf = {
    b.total += in.age
    b.count += 1
    b
  }
  // Merge buffers from different partitions
  override def merge(b1: Buf, b2: Buf): Buf = {
    b1.total += b2.total
    b1.count += b2.count
    b1
  }
  // Compute the final result
  override def finish(reduction: Buf): Double = {
    reduction.total.toDouble / reduction.count
  }
  // Distributed execution requires serialization: use Encoders.product for custom case classes
  override def bufferEncoder: Encoder[Buf] = Encoders.product
  // ...and Encoders.scalaXxx (e.g. Encoders.scalaDouble) for Scala primitive types
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
+-----------------------------------------------------------------+
|MyAvgUDAF(com.mingyu.spark.sql.Spark02_SparkSQL_aggregator2$User)|
+-----------------------------------------------------------------+
| 30.0|
+-----------------------------------------------------------------+
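The generated column name is verbose; since TypedColumn inherits Column's name method, the result can be given a friendlier alias, e.g.:
ds.select(new MyAvgUDAF().toColumn.name("avgAge")).show()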