Spark 自定义UDAF

前言

需求:业务需求要求求出score的最大值(max),最小值(min),均值(mean),标准差(stddev),中位数。需求的前四个值Spark自带函数可以解决,唯独中位数没有,所以需要自定义一个聚合函数。

实现方法以及代码

自定义函数需要继承UserDefinedAggregateFunction

class MiddleValueUDAF extends UserDefinedAggregateFunction{

// 输入参数的数据类型
  override def inputSchema: StructType = {
     DataTypes.createStructType(util.Arrays
      .asList((DataTypes.createStructField("score",DataTypes.StringType,true))))
  }

  /**
    *
    * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
    * buffer.getInt(0)获取的是上一次聚合后的值
    * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
    * 大聚和发生在reduce端.
    * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

buffer.update(0,Integer.valueOf(buffer.get(0).toString)+Integer.valueOf(input.get(0).toString))

    buffer.update(0,buffer.get(0)+","+input.get(0).toString)
  }

//  buffer中的数据类型
  override def bufferSchema: StructType = {
     DataTypes.createStructType(util.Arrays
      .asList((DataTypes.createStructField("summ",DataTypes.StringType,true))))

  }
  /**
    * 合并其他部分结果
    * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
    * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
    * buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值       
    * buffer2.getInt(0) : 这次计算传入进来的update的结果
    * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
    * 也可以是一个节点里面的多个executor合并
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,Integer.valueOf(buffer1.get(0).toString)+Integer.valueOf(buffer2.get(0).toString))

    buffer1.update(0,buffer1.get(0)+","+buffer2.get(0).toString)

  }

  //初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0,"")
  }

  // 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
  override def deterministic: Boolean = {

     true
  }

//计算逻辑
  override def evaluate(buffer: Row): Any = {
     val intArray = buffer.get(0).toString.replaceAll(",,",",").substring(1)
    val list = intArray.split(",").map(_.toDouble).toList.sorted
    val len = list.size
    var mid = 0d
    if (len % 2 == 0)
      mid = (list(len / 2 - 1) + list(len / 2)) / 2
    else
      mid = list(len / 2)
    mid
  }


  // 返回值的类型
  override def dataType: DataType = {

     DataTypes.DoubleType
  }

你可能感兴趣的:(Spark 自定义UDAF)