// Register the function
spark.udf.register("prefix1", (name: String) => {
  "Name:" + name
})
// Use the function in SQL
spark.sql("select *, prefix1(name) from users").show()
// 1 Define the UDAF (untyped; in versions before 3.0.0 this API is not marked deprecated)
package com.shufang.rdd_ds_df

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

class MyUDAF extends UserDefinedAggregateFunction {

  // IN: schema of the input column(s)
  override def inputSchema: StructType = {
    StructType(
      Array(
        StructField("age", LongType)
      )
    )
  }

  // MIDDLE: schema of the aggregation buffer
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total", LongType),
        StructField("count", LongType)
      )
    )
  }

  // OUT: type of the final result
  override def dataType: DataType = LongType

  // Whether the function always returns the same result for the same input
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // equivalent: buffer(0) = 0L; buffer(1) = 0L
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1)
  }

  // Merge two buffers (e.g. from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Compute the average
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1)
  }
}
// 2 Register & use
spark.udf.register("ageAvg", new MyUDAF)
spark.sql("select ageAvg(id) as av from users").show()
// 1 Declare and implement
package com.shufang.rdd_ds_df

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Since 3.0.0, Aggregator[IN, BUF, OUT] should be registered as a UDF
 * via the functions.udaf(agg) method.
 */
case class Buff(var total: Long, var count: Long)
class MyUDAF1 extends Aggregator[Long, Buff, Long] {

  // Initialize the buffer
  override def zero: Buff = Buff(0L, 0L)

  // Fold an incoming element into the buffer
  override def reduce(b: Buff, a: Long): Buff = {
    b.count += 1
    b.total += a
    b
  }

  // Merge two buffers
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.count = b1.count + b2.count
    b1.total = b1.total + b2.total
    b1
  }

  // Compute the final result (the average)
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Serialization encoder for the buffer type
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Serialization encoder for the output type
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
// 2 Register & use; the registration differs: wrap the Aggregator with functions.udaf
import org.apache.spark.sql.functions

spark.udf.register("ageAvg", functions.udaf(new MyUDAF1()))
spark.sql("select ageAvg(id) as av from users").show()
For versions before 3.0.0, the strongly typed Aggregator has to be used through the DSL, Spark SQL's domain-specific language, rather than registered for plain SQL.
// 1 Declare it: each row of the Dataset acts as the input parameter
package com.shufang.rdd_ds_df

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Since 3.0.0, Aggregator[IN, BUF, OUT] should be registered as a UDF
 * via the functions.udaf(agg) method.
 */
// Reuses the Buff case class defined above
class MyUDAF2 extends Aggregator[User, Buff, Long] {

  // Initialize the buffer
  override def zero: Buff = Buff(0L, 0L)

  // Fold an incoming row (a User) into the buffer
  override def reduce(b: Buff, a: User): Buff = {
    b.count += 1
    b.total += a.id
    b
  }

  // Merge two buffers
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.count = b1.count + b2.count
    b1.total = b1.total + b2.total
    b1
  }

  // Compute the final result
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  override def bufferEncoder: Encoder[Buff] = Encoders.product

  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
// 2 Use it via the DSL
import org.apache.spark.sql.{Dataset, TypedColumn}
import spark.implicits._

val column: TypedColumn[User, Long] = new MyUDAF2().toColumn
val ds: Dataset[User] = df.as[User]
ds.select(column).show()
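As a small follow-up, TypedColumn can be given a readable output name via name (a sketch against the same ds):

ds.select(new MyUDAF2().toColumn.name("ageAvg")).show()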