Custom Functions in SparkSQL

Contents

1. UDF

2. UDAF (untyped)

3. Aggregator (strongly typed)

4. Early strongly typed UDAF usage in Spark


Custom functions can be registered through spark.udf to implement user-defined logic.

1. UDF

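The examples below all read data/user.json. Judging from the query output further down, the file presumably contains JSON Lines records like the following (a reconstruction for illustration, not the original file):

{"username": "小王", "age": 20}
{"username": "小李", "age": 30}
{"username": "小黑", "age": 40}
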
// `spark` is a SparkSession, created as in the full examples below
val df: DataFrame = spark.read.json("data/user.json")
df.createOrReplaceTempView("user")

// Register SparkSQL custom functions (UDFs)
spark.udf.register("prefixName", (name: String) => {
	"Name:" + name
})

// Appends the suffix "岁" (years old) to the age
spark.udf.register("tailSui", (age: Int) => {
	age + "岁"
})

spark.sql(
	"""
		|select
		|tailSui(age),
		|prefixName(username)
		|from
		|user
	""".stripMargin).show()
+------------+--------------------+
|tailSui(age)|prefixName(username)|
+------------+--------------------+
|        20岁|           Name:小王|
|        30岁|           Name:小李|
|        40岁|           Name:小黑|
+------------+--------------------+
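A registered UDF is only visible to SQL. The same function can also be used directly with the DataFrame API through org.apache.spark.sql.functions.udf; a minimal sketch (column name taken from the JSON above):

import org.apache.spark.sql.functions.udf
import spark.implicits._ // enables the $"column" syntax

val prefixName = udf((name: String) => "Name:" + name)
df.select(prefixName($"username")).show()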

2. UDAF (untyped)

Both the strongly typed Dataset and the untyped DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions. An untyped user-defined aggregate function (UDAF) is implemented by extending UserDefinedAggregateFunction. Since Spark 3.0, UserDefinedAggregateFunction is deprecated; the strongly typed Aggregator is the recommended replacement.

Deprecated since Spark 3.0, shown here for completeness.

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

def main(args: Array[String]): Unit = {

val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._

val df: DataFrame = spark.read.json("data/user.json")
df.createTempView("user")

spark.udf.register("MyAvg", new MyAggregateBuffer)

spark.sql("select MyAvg(age) from user").show()


spark.stop()
}

class MyAggregateBuffer extends UserDefinedAggregateFunction {
// Schema of the input data
override def inputSchema: StructType = StructType(
	Array(
		StructField("age", LongType)
	)
)

// Schema of the aggregation buffer (intermediate state)
override def bufferSchema: StructType = StructType(
	Array(
		StructField("total", LongType),
		StructField("count", LongType)
	)
)

// Output type of the aggregate result
override def dataType: DataType = LongType

// Deterministic: the same input always produces the same output
override def deterministic: Boolean = true

// Initialize the buffer
override def initialize(buffer: MutableAggregationBuffer): Unit = {
	buffer(0) = 0L
	buffer(1) = 0L
}

// Update the buffer with one input row
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
	buffer.update(0, buffer.getLong(0) + input.getLong(0)) // running sum of ages at index 0
	buffer.update(1, buffer.getLong(1) + 1) // row count at index 1
}

// Merge buffers computed on different partitions
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
	buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
	buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}

// Compute the average (integer division, since dataType is LongType)
override def evaluate(buffer: Row): Any = {
	buffer.getLong(0) / buffer.getLong(1)
}
}
+----------------------+
|myaggregatebuffer(age)|
+----------------------+
|                    30|
+----------------------+
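Because evaluate divides two Longs, the result is truncated toward zero (the output above happens to be exact). A minimal change for a fractional average with this deprecated API would be to switch the output type and division, along these lines:

// Report the result as a Double instead of a Long
override def dataType: DataType = DoubleType

// Floating-point division avoids the truncation
override def evaluate(buffer: Row): Any = {
	buffer.getLong(0).toDouble / buffer.getLong(1)
}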


3. Aggregator (strongly typed)

Available since Spark 3.0.

Aggregator is a type-safe abstraction for aggregating data. Its input, intermediate buffer, and output types are declared as type parameters, and it defines methods that fold input values into the buffer, merge buffers, and produce the final output. The basic usage is as follows:

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}

def main(args: Array[String]): Unit = {
	
	val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
	val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
	
	val df: DataFrame = spark.read.json("data/user.json")
	df.createTempView("user")
	
	spark.udf.register("MyAvg", functions.udaf(new MyAvgUDAF))
	
	spark.sql("select MyAvg(age) from user").show()
	
	
	spark.stop()
}

// org.apache.spark.sql.expressions.Aggregator
// IN (input type): Long
// BUF (buffer type): Buf
// OUT (output type): Double
case class Buf(var total: Long, var count: Long)

class MyAvgUDAF extends Aggregator[Long, Buf, Double] {
	
	// Initial (zero) value of the buffer
	override def zero: Buf = {
		Buf(0L, 0L)
	}
	
	// Fold one input value into the buffer
	override def reduce(b: Buf, in: Long): Buf = {
		b.total += in
		b.count += 1
		b
	}
	
	// Merge buffers computed on different partitions
	override def merge(b1: Buf, b2: Buf): Buf = {
		b1.total += b2.total
		b1.count += b2.count
		b1
	}
	
	// Compute the final result from the buffer
	override def finish(reduction: Buf): Double = {
		reduction.total.toDouble / reduction.count
	}
	
	// Encoder for the buffer: Encoders.product serializes case classes (Product types)
	override def bufferEncoder: Encoder[Buf] = Encoders.product
	
	// Encoder for the output: Scala primitive types use Encoders.scalaXxx
	override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
+--------------+
|myavgudaf(age)|
+--------------+
|          30.0|
+--------------+
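Since functions.udaf wraps the Aggregator into an ordinary UserDefinedFunction, it can also be applied as a column expression without registering it for SQL; a minimal sketch:

val myAvg = functions.udaf(new MyAvgUDAF)
df.select(myAvg(df("age"))).show()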

4. Early strongly typed UDAF usage in Spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession, TypedColumn}

def main(args: Array[String]): Unit = {
	
	val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
	val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
	import spark.implicits._
	val df: DataFrame = spark.read.json("data/user.json")
	
	// Early strongly typed UDAFs are applied through the DSL (Dataset) syntax
	val ds: Dataset[User] = df.as[User]
	
	// Convert the UDAF into a TypedColumn usable in select
	val udafCol: TypedColumn[User, Double] = new MyAvgUDAF().toColumn
	
	ds.select(udafCol).show()
	spark.stop()
}

case class User(var username: String, var age: Long)

// org.apache.spark.sql.expressions.Aggregator
// IN (input type): User
// BUF (buffer type): Buf
// OUT (output type): Double
case class Buf(var total: Long, var count: Long)

class MyAvgUDAF extends Aggregator[User, Buf, Double] {
	
	// Initial (zero) value of the buffer
	override def zero: Buf = {
		Buf(0L, 0L)
	}
	
	// Fold one input row into the buffer
	override def reduce(b: Buf, in: User): Buf = {
		b.total += in.age
		b.count += 1
		b
	}
	
	// Merge buffers computed on different partitions
	override def merge(b1: Buf, b2: Buf): Buf = {
		b1.total += b2.total
		b1.count += b2.count
		b1
	}
	
	// Compute the final result from the buffer
	override def finish(reduction: Buf): Double = {
		reduction.total.toDouble / reduction.count
	}
	
	// Encoder for the buffer: Encoders.product serializes case classes (Product types)
	override def bufferEncoder: Encoder[Buf] = Encoders.product
	
	// Encoder for the output: Scala primitive types use Encoders.scalaXxx
	override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
+-----------------------------------------------------------------+
|MyAvgUDAF(com.mingyu.spark.sql.Spark02_SparkSQL_aggregator2$User)|
+-----------------------------------------------------------------+
|                                                             30.0|
+-----------------------------------------------------------------+
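The toColumn approach only works on the typed Dataset API. To get a friendlier column name than the auto-generated one above, the TypedColumn can be aliased; a small sketch (the alias "avgAge" is illustrative):

val named: TypedColumn[User, Double] = new MyAvgUDAF().toColumn.name("avgAge")
ds.select(named).show()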
