Writing a Custom UDAF in Spark

Hive has both UDFs and UDAFs. Spark supported UDFs early on, but UDAF (User Defined Aggregate Function) support was only introduced in Spark 1.5.x.

Where a UDF maps a single row to a single value, a UDAF aggregates over multiple input rows to produce one result — for example, upper(name) is a UDF, while avg(num) is a UDAF.

Writing a UDAF that computes an average


1. A custom UDAF extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction and implements its 8 abstract methods:

package com.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * A UDAF that computes the average of a Double column.
 */
class NumsAvg extends UserDefinedAggregateFunction {
  // Schema of the input arguments: a single Double column
  def inputSchema: StructType =
    StructType(StructField("nums", DoubleType) :: Nil)

  // Schema of the aggregation buffer: a row count and a running sum
  def bufferSchema: StructType = StructType(
    StructField("cnt", LongType) ::
      StructField("sum", DoubleType) :: Nil)

  // Type of the final result
  def dataType: DataType = DoubleType

  // The same input always yields the same output
  def deterministic: Boolean = true

  // Initialize the buffer: zero rows seen, zero sum
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the buffer
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
  }

  // Merge two partial buffers (e.g. from different partitions)
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Compute the final result, rounded to five decimal places
  def evaluate(buffer: Row): Any = {
    val t = buffer.getDouble(1) / buffer.getLong(0)
    f"$t%1.5f".toDouble
  }
}
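Conceptually, Spark drives the buffer through initialize → update (once per row) → merge (across partitions) → evaluate. The lifecycle can be sketched in plain Scala with a tuple standing in for the buffer (no Spark dependency; the object and method names below are illustrative, not part of the Spark API):

```scala
object BufferLifecycleSketch {
  // Mirrors bufferSchema: (cnt: Long, sum: Double)
  type Buf = (Long, Double)

  // like initialize(): zero rows seen, zero sum
  def initialize: Buf = (0L, 0.0)

  // like update(): fold one input value into the buffer
  def update(b: Buf, num: Double): Buf = (b._1 + 1, b._2 + num)

  // like merge(): combine two partial buffers
  def merge(a: Buf, b: Buf): Buf = (a._1 + b._1, a._2 + b._2)

  // like evaluate(): final result from the merged buffer
  def evaluate(b: Buf): Double = b._2 / b._1

  def main(args: Array[String]): Unit = {
    // Pretend partition 1 saw 4.5 and 2.4, partition 2 saw 1.8
    val p1 = update(update(initialize, 4.5), 2.4)
    val p2 = update(initialize, 1.8)
    val result = evaluate(merge(p1, p2))
    println(result) // close to 2.9, with the usual floating-point error
  }
}
```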

2. Testing the custom UDAF

Run the built-in avg() function and the custom numsAvg side by side:

package com.spark.sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object NumsAvgTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UDAF").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Build a one-column DataFrame of Doubles
    val nums = List(4.5, 2.4, 1.8)
    val numsRDD = sc.parallelize(nums, 1)
    val numsRowRDD = numsRDD.map { x => Row(x) }
    val structType = StructType(Array(StructField("num", DoubleType, true)))
    val numsDF = sqlContext.createDataFrame(numsRowRDD, structType)

    // Register a temp table so the data can be queried with SQL
    numsDF.registerTempTable("numtest")

    // Built-in avg()
    sqlContext.sql("select avg(num) from numtest").collect().foreach { x => println(x) }

    // Register the custom UDAF, then call it from SQL
    sqlContext.udf.register("numsAvg", new NumsAvg)
    sqlContext.sql("select numsAvg(num) from numtest").collect().foreach { x => println(x) }
  }
}

Comparing the results

Built-in avg():

[2.9000000000000004]

Custom numsAvg:

[2.9]

With a custom function you control the precision of the result yourself.
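The rounding in evaluate() comes from formatting the raw quotient to five decimal places and parsing it back. In plain Scala (the object name is illustrative):

```scala
object PrecisionDemo {
  def main(args: Array[String]): Unit = {
    // The raw average from the test above, with accumulated floating-point error
    val raw = 2.9000000000000004
    // evaluate() formats to five decimal places, then parses the string back
    val rounded = f"$raw%1.5f".toDouble
    println(rounded) // prints 2.9
  }
}
```

This is a string round-trip, not numeric rounding; math.rint or BigDecimal.setScale would achieve a similar effect without the intermediate string.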


3. In short: write a UDAF by implementing the UserDefinedAggregateFunction interface, then register it with sqlContext.udf.register before calling it from SQL.

