Spark 自定义 UDAF 函数

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
 * Aggregate function that concatenates the distinct, non-empty string values of a group
 * into one comma-separated string (a `GROUP_CONCAT(DISTINCT ...)` equivalent).
 *
 * Deduplication compares whole comma-separated elements, so `"Bei"` is not mistaken for a
 * duplicate of `"Beijing"`. Null and empty inputs are ignored. Because the buffer uses `","`
 * as its separator, an input value that itself contains commas is treated as several
 * elements.
 *
 * Elements are kept in the order this buffer first sees them. The final order of the result
 * depends on how Spark partitions and merges partial buffers.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `org.apache.spark.sql.expressions.Aggregator` registered through `functions.udaf`;
 * consider migrating.
 */
class GroupConcatDistinct extends UserDefinedAggregateFunction {
  /** Input: a single nullable string column. */
  override def inputSchema: StructType = StructType(List(StructField("cityInfo", StringType, true)))

  /** Aggregation buffer: the comma-separated distinct values collected so far. */
  override def bufferSchema: StructType = StructType(List(StructField("buffCityInfo", StringType, true)))

  /** Output: the comma-separated string of distinct values. */
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  /** Starts every buffer as the empty string, i.e. "no values seen yet". */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  /** Adds one input row's value(s) to the buffer, skipping nulls, empties and duplicates. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // getString returns null for a null cell; splitElements turns that into "no elements".
    buffer.update(0, appendDistinct(buffer.getString(0), splitElements(input.getString(0))))
  }

  /** Merges a partial buffer into `buffer1`, appending only elements not already present. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, appendDistinct(buffer1.getString(0), splitElements(buffer2.getString(0))))
  }

  /** Returns the accumulated comma-separated string ("" if no non-empty values were seen). */
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }

  /** Splits a comma-separated string into its non-empty elements; null or "" yields none. */
  private def splitElements(joined: String): Vector[String] =
    Option(joined).toVector.flatMap(s => s.split(",").toVector).filter(_.nonEmpty)

  /**
   * Appends each of `values` to the comma-separated `buffered` string unless an identical
   * element is already present. Existing order is kept and new elements go at the end.
   */
  private def appendDistinct(buffered: String, values: Iterable[String]): String = {
    val existing = splitElements(buffered)
    // Carry the ordered elements and a set for exact (not substring) membership checks.
    val (merged, _) = values.foldLeft((existing, existing.toSet)) {
      case ((acc, seen), value) if value.nonEmpty && !seen(value) => (acc :+ value, seen + value)
      case (state, _) => state
    }
    merged.mkString(",")
  }
}

 

你可能感兴趣的:(spark,spark)