UDF: a user-defined function; one input value produces one output value (applied per row).
UDAF: a user-defined aggregate function; many input values produce one output value (applied per group). In the code below, format_double is a UDF and self_avg is a UDAF.
(1) UDAF code:
package _0728sql
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * A custom average (AVG) implementation based on UserDefinedAggregateFunction.
 */
object Avg_UDAF extends UserDefinedAggregateFunction {
  // (a)
  override def inputSchema: StructType = {
    /**
     * The schema (data types) of the UDAF's input parameters.
     * "iv" stands for input value.
     */
    StructType(Array(
      StructField("iv", DoubleType)
    ))
  }
  // (b)
  override def bufferSchema: StructType = {
    // The schema of the aggregation buffer: avg = totalValue / totalCount
    // "tv": total value
    // "tc": total count
    StructType(Array(
      StructField("tv", DoubleType),
      StructField("tc", IntegerType)
    ))
  }
  // (c)
  override def dataType: DataType = {
    // The data type of the value returned by the UDAF
    DoubleType
  }
  // (d)
  override def deterministic: Boolean = {
    // Whether the function always returns the same result for the same input;
    // true means deterministic. This is almost always true.
    true
  }
  // (e)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initial values of the aggregation buffer
    buffer.update(0, 0.0)
    buffer.update(1, 0)
  }
  // (f)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Called once per input row of the current group to update the buffer values
    // 1. Read the input value
    val iv = input.getDouble(0)
    // 2. Read the current buffer values
    val tv = buffer.getDouble(0)
    val tc = buffer.getInt(1)
    // 3. Update the buffer
    buffer.update(0, tv + iv)
    buffer.update(1, tc + 1)
  }
  // (g)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Called when the partial results of two partitions have to be combined
    // 1. Read buffer1's values
    val tv1 = buffer1.getDouble(0)
    val tc1 = buffer1.getInt(1)
    // 2. Read buffer2's values
    val tv2 = buffer2.getDouble(0)
    val tc2 = buffer2.getInt(1)
    /*
      3. Write the combined result back into buffer1, never into buffer2:
         only MutableAggregationBuffer is mutable (it is the type that provides update).
     */
    buffer1.update(0, tv1 + tv2)
    buffer1.update(1, tc1 + tc2)
  }
  // (h)
  override def evaluate(buffer: Row): Any = {
    // Compute the final result from the buffer: average = total value / total count
    val tv = buffer.getDouble(0)
    val tc = buffer.getInt(1)
    tv / tc
  }
}
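A hand-traced example of the buffer life cycle, assuming the values 10.0 and 20.0 land in partition A and 30.0 in partition B:
partition A: initialize -> (0.0, 0), update(10.0) -> (10.0, 1), update(20.0) -> (30.0, 2)
partition B: initialize -> (0.0, 0), update(30.0) -> (30.0, 1)
merge(A, B): buffer1 becomes (60.0, 3)
evaluate: 60.0 / 3 = 20.0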
(2) Full driver code
package _0728sql
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession}
//import _0728sql.Avg_UDAF
/**
 * Registers and uses the UDF (format_double) and the UDAF (self_avg) through Spark SQL.
 */
object UDFandUDAF extends App {
  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("UDFandUDAF")
  // getOrCreate acts like a lock: it guarantees that only one SparkContext exists
  val sc = SparkContext.getOrCreate(conf)
  // SQLContext is sufficient here; use HiveContext only when Hive support is needed
  val sqlContext = new SQLContext(sc)
  // 1. UDF
  // format_double keeps two decimal places; the first argument is the SQL function name,
  // the second is the anonymous function that implements it
  sqlContext.udf.register("format_double", (value: Double) => {
    import java.math.BigDecimal
    val bd = new BigDecimal(value)
    bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue()
  })
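  // For example (hand-checked): format_double(9014.6666...) returns 9014.67
  // (half-up rounding to two decimal places)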
  import sqlContext.implicits._
  sc.parallelize(Array(
    (1, 1234),
    (1, 45212),
    (1, 22125),
    (1, 12521),
    (1, 12352),
    (2, 52352),
    (2, 2232),
    (2, 12521),
    (2, 12323),
    (3, 2253),
    (3, 2233),
    (3, 22558),
    (4, 252),
    (4, 235),
    (5, 523)
  )).toDF("id", "sal").registerTempTable("tmp_emp")
  sqlContext.sql(
    """
      |select
      |id, AVG(sal) as sal1,
      |format_double(AVG(sal)) as sal3
      |from tmp_emp
      |group by id
    """.stripMargin).show
  // 2. UDAF
  sqlContext.udf.register("self_avg", Avg_UDAF)
  sqlContext.sql(
    """
      |select
      |id, AVG(sal) as sal1,
      |format_double(AVG(sal)) as sal3,
      |format_double(self_avg(sal)) as sal4
      |from tmp_emp
      |group by id
    """.stripMargin).show
}
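Since Spark 2.x, the same setup can also go through the SparkSession entry point (which is why SparkSession is imported above). A minimal sketch, assuming Spark 2.x or later and the same package as Avg_UDAF; the object and view names here are illustrative, and createOrReplaceTempView replaces the older registerTempTable:
package _0728sql
import org.apache.spark.sql.SparkSession
object UDFandUDAF_SparkSession extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("UDFandUDAF_SparkSession")
    .getOrCreate()
  import spark.implicits._
  // Register the same UDF and UDAF on the SparkSession
  spark.udf.register("format_double", (value: Double) => {
    new java.math.BigDecimal(value).setScale(2, java.math.BigDecimal.ROUND_HALF_UP).doubleValue()
  })
  spark.udf.register("self_avg", Avg_UDAF)
  // A small sample DataFrame registered as a temporary view
  Seq((1, 1234), (1, 45212), (2, 52352), (2, 2232)).toDF("id", "sal")
    .createOrReplaceTempView("tmp_emp")
  spark.sql(
    """
      |select
      |id, format_double(self_avg(sal)) as sal4
      |from tmp_emp
      |group by id
    """.stripMargin).show()
}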