SparkSQL自定义函数(实现几何平均数)

SparkSQL-自定义聚合函数 (实现几何平均数)


->创建SparkSessionparkSession

->创建自定义函数
    -1、继承UserDefinedAggregateFunction
    -2、重写下面的方法    
        inputSchema        -输入数据的类型
        bufferSchema       -产生中间结果的数据类型
        dataType           -最终返回的结果类型
        deterministic      -确保一致性
        initialize         -指定初始值
        update             -每有一条数据参与运算就更新一下中间结果
        merge              -全局聚合
        evaluate           -计算最终结果
        
    !:StructField        -哪些列,啥类型
->实例化自定义函数,并注册自定义函数(spark.udf.register)

 

代码:

object Geometric {
  def main(args: Array[String]): Unit = {
    //创建sparkSession
    val sparkSession: SparkSession = SparkSession.builder().appName("Geometric").master("local[*]").getOrCreate()
    //造数据 1~10
    val range: Dataset[lang.Long] = sparkSession.range(1, 11)

    //实例化Geom类
    val geomean = new Geom
    //注册视图
    range.createTempView("v_range")
    //注册自定义函数
    sparkSession.udf.register("ge", geomean)
    //执行sparkSql语句
    val res: DataFrame = sparkSession.sql("select ge(id) result from v_range")

    res.show()

    sparkSession.stop()
  }
}

class Geom extends UserDefinedAggregateFunction {
  //输入类型
  override def inputSchema: StructType = StructType(List(StructField("value", DoubleType)))

  //中间数据
  override def bufferSchema: StructType = StructType(List(
    StructField("product", DoubleType),
    StructField("counts", LongType)
  ))

  //最终返回结果类型
  override def dataType: DataType = DoubleType

  //确保一致性
  override def deterministic: Boolean = true

  //指定初始值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1.0
    buffer(1) = 0L
  }

  //每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) * input.getDouble(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  //全局聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)

  }

  override def evaluate(buffer: Row): Double = {

    math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))

  }
}

 

你可能感兴趣的:(大数据开发)