SparkSQL 之自定义函数UDAF

需求:计算1-10的几何平均数
实现方式:继承 UserDefinedAggregateFunction 并重写其中的方法,各方法的含义见代码注释。

package cn.UDAF

import java.lang
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * Demo entry point: computes the geometric mean of the integers 1 to 10
  * using the `gm` user-defined aggregate function.
  *
  * @Author xiaohuli
  * @CreateDate 2019/2/6
  */
object UdafTest {
    /**
      * Starts a local Spark session, registers [[GeoMean]] under the SQL name `gm`,
      * applies it to the `id` column of a 1..10 range and prints the result.
      *
      * @param args command-line arguments (unused)
      */
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("UdafTest")
          .master("local[4]")
          .getOrCreate()
        try {
            // Build the input Dataset: range(1, 11) is meant to produce the values 1 to 10
            val range: Dataset[lang.Long] = spark.range(1, 11)
            // Register the aggregate function so it can be called from SQL as gm(...)
            spark.udf.register("gm", GeoMean)
            // getOrCreate() may return an already-running session in which the view
            // exists; replacing it avoids an "already exists" failure on re-runs
            range.createOrReplaceTempView("v_range")
            // Run the aggregation through SQL and print the single-row result
            val result = spark.sql("SELECT gm(id) result FROM v_range")
            result.show()
        } finally {
            // Always release the session, even when one of the steps above throws
            spark.stop()
        }
    }
}

/**
  * User-defined aggregate function computing the geometric mean of a numeric column:
  * the n-th root of the product of the n non-null input values.
  *
  * The aggregation buffer holds the running product and the number of values
  * multiplied into it. NULL inputs are skipped. When no value was aggregated the
  * result is NaN.
  *
  * NOTE: the running product is a plain Double, so it can overflow to Infinity
  * (or underflow to 0.0) on long or large-valued inputs.
  */
object GeoMean extends UserDefinedAggregateFunction {
    /** Schema of the input arguments: a single Double column named "value". */
    override def inputSchema: StructType = StructType(List(
        // Name and data type of the input argument
        StructField("value", DoubleType)
    ))

    /** Schema of the intermediate aggregation buffer. */
    override def bufferSchema: StructType = StructType(List(
        // Running product of the values seen so far
        StructField("product", DoubleType),
        // Number of values that took part in the product
        StructField("counts", LongType)
    ))

    /** Data type of the final result. */
    override def dataType: DataType = DoubleType

    /** The same input always yields the same output. */
    override def deterministic: Boolean = true

    /**
      * Sets the buffer to its starting state.
      *
      * @param buffer aggregation buffer to reset
      */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        // 1.0 is the multiplicative identity
        buffer(0) = 1.0
        // No values counted yet
        buffer(1) = 0L
    }

    /**
      * Partial (per-partition) aggregation: folds one input row into the buffer.
      *
      * @param buffer aggregation buffer to update in place
      * @param input  row holding the input value at index 0
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        // Skip NULL inputs instead of failing in getDouble; they neither change
        // the product nor count towards the root
        if (!input.isNullAt(0)) {
            // Multiply the input value into the running product
            buffer(0) = buffer.getDouble(0) * input.getDouble(0)
            // One more value took part in the computation
            buffer(1) = buffer.getLong(1) + 1L
        }
    }

    /**
      * Global aggregation: merges a second partial buffer into the first one.
      *
      * @param buffer1 buffer receiving the merged state
      * @param buffer2 partial buffer to merge in
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        // Multiply the partial products together
        buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
        // Add up the partial counts to get the total count
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    /**
      * Computes the final result from the fully merged buffer.
      *
      * @param buffer merged aggregation buffer
      * @return the count-th root of the product, or NaN when nothing was aggregated
      */
    override def evaluate(buffer: Row): Double = {
        val count = buffer.getLong(1)
        // Make the empty case explicit rather than relying on pow(1.0, Infinity)
        if (count == 0L) Double.NaN
        else math.pow(buffer.getDouble(0), 1.0 / count)
    }
}

你可能感兴趣的:(Spark)