Scala UDAF + Spark SQL 实现求中位数

首先,用 Scala 编写一个 UDAF(用户自定义聚合函数):


import scala.collection.mutable.{ArrayBuffer, WrappedArray}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._


/**
 * Spark SQL user-defined aggregate function that computes the exact median
 * of a `DOUBLE` column.
 *
 * Every non-null input value of a group is collected into the aggregation
 * buffer and sorted in `evaluate`. Memory per group therefore grows linearly
 * with the group size. This is fine for modest groups, but large groups may
 * want an approximate percentile instead.
 *
 * NULL inputs are ignored. A group with no non-null input yields SQL NULL.
 */
class UDAFMedian extends UserDefinedAggregateFunction {

  /** Input schema: a single (nullable) double column named "value". */
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  /** Buffer schema: the list of non-null values collected so far. */
  def bufferSchema: StructType = StructType(
    StructField("data_list", ArrayType(DoubleType, false)) :: Nil
  )

  /** The aggregate result is a double (or NULL for an empty group). */
  def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  def deterministic: Boolean = true

  /** Start each group with an empty value list. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ArrayBuffer.empty[Double]
  }

  /**
   * Appends one input value to the buffer.
   *
   * NULL inputs are skipped. Without the check, `getAs[Double]` on a NULL
   * unboxes to 0.0 and silently counts the NULL as a zero.
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      // getSeq is used instead of a concrete WrappedArray/ArraySeq type so
      // the code does not depend on the Scala collections version.
      buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }
  }

  /** Concatenates the value lists of two partial aggregations. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
  }

  /**
   * Computes the median of the collected values.
   *
   * @return the middle value for an odd count, the mean of the two middle
   *         values for an even count, or `null` (SQL NULL) when the group
   *         had no non-null input
   */
  def evaluate(buffer: Row): Any = {
    val sorted = buffer.getSeq[Double](0).sorted.toIndexedSeq
    val n = sorted.size
    if (n == 0) null
    else if (n % 2 == 0) (sorted(n / 2 - 1) + sorted(n / 2)) / 2
    else sorted(n / 2)
  }

}

然后,注册该 UDAF 并在 SQL 中使用:

import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ListBuffer

/**
 * Example driver: registers [[UDAFMedian]] as the SQL function `median` and
 * prints the per-class median score, one "class,median" line per group.
 */
object TestMedian {
  def main(args: Array[String]): Unit = {

    val ss: SparkSession = SparkSession.builder().master("local").enableHiveSupport().getOrCreate()

    try {
      // Register the custom UDAF under the name "median".
      ss.udf.register("median", new UDAFMedian())

      // Use median in SQL to compute the median score per class.
      // NOTE(review): assumes a table `scores` with columns `class` and `score` exists.
      val sql = "select class, median(score) from scores group by class"
      val rows = ss.sql(sql).collect()

      // Render each row as comma-separated values; NULL becomes an empty field.
      val result: Seq[String] = rows.toSeq.map { row =>
        row.toSeq.map(value => Option(value).fold("")(_.toString)).mkString(",")
      }

      result.foreach(println)
    } finally {
      // Release the local Spark context even if the query fails.
      ss.stop()
    }
  }
}

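官方 UDAF 示例参考:https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html (注意:自 Spark 3.0 起 `UserDefinedAggregateFunction` 已被标记为弃用,推荐改用 `Aggregator` 并通过 `functions.udaf` 注册。)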

你可能感兴趣的:(Scala UDAF + Spark Sql实现求中位数)