Spark 2.3.0 用户自定义聚合函数UserDefinedAggregateFunction和Aggregator

Spark 2.3.0 用户自定义聚合函数UserDefinedAggregateFunction和Aggregator 

一、无类型的用户自定于聚合函数(Untyped User-Defined Aggregate Functions)

实现无类型的用户自定于聚合函数需要继承抽象类UserDefinedAggregateFunction,并重写该类的8个函数。我们以计算数据类型为Double的列score的平均值为例进行详细说明。score来源于数据文件itemdata.data,格式如下:

0162381440670851711,4,7.0
0162381440670851711,11,4.0
0162381440670851711,32,1.0
0162381440670851711,176,27.0
0162381440670851711,183,11.0
0162381440670851711,184,5.0
0162381440670851711,207,9.0
0162381440670851711,256,3.0
0162381440670851711,258,4.0
0162381440670851711,259,16.0
0162381440670851711,260,8.0
0162381440670851711,261,18.0
0162381440670851711,301,1.0

第一列为user_id,第二列为item_id,第三列为score。

1、inputSchema

        定义输入数据的Schema,要求类型是StructType,它的参数是由StructField类型构成的数组。比如这里要定义score列的Schema,首先使用StructField声明score列的名字score_column,数据类型为DoubleType。这里输入只有score这一列,所以StructField构成的数组只有一个元素。如下:

override def inputSchema: StructType = StructType(StructField("score_column",DoubleType)::Nil)

::是Scala中的操作符与Nil空集合操作后生成一个数组。

 

2、bufferSchema

        事实上,计算score的平均值时,需要用到score的总和sum以及score的总个数count这样的中间数据,那么就使用bufferSchema来定义它们。如下:

override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil)

这里StructField类型的数组就有两个元素:数据类型为DoubleType的sum和数据类型为LongType类型的count。

 

3、dataType

        我们需要对自定义聚合函数的最终数据类型进行说明,使用dataType函数。比如计算出的平均score是Double类型,如下定义:

override def dataType: DataType = DoubleType

 

4、deterministic

deterministic函数用于对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出。因为对于同样的score输入,肯定要得到相同的score平均值,所以定义为true,如下:

override def deterministic: Boolean = true

 

5、initialize

initialize用户初始化缓存数据。比如score的缓存数据有两个:sum和count,需要初始化为sum=0.0和count=0L,第一个初始化为Double类型,第二个初始化为长整型。如下:

override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //sum=0.0
      buffer(0)=0.0
      //count=0
      buffer(1)=0L
    }

 

6、update

当有新的输入数据时,update用户更新缓存变量。比如这里当有新的score输入时,需要将它的值更新变量sum中,并将count加1,如下:

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //输入非空
      if(!input.isNullAt(0)){
        //sum=sum+输入的score
        buffer(0)=buffer.getDouble(0)+input.getDouble(0)
        //count=count+1
        buffer(1)=buffer.getLong(1)+1
      }
    }

 

7、merge

merge将更新的缓存变量存入到缓存中。如下:

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)
      buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
    }

 

8、evaluate

evaluate是一个计算方法,用于计算我们的最终结果。比如这里用于计算平均得分average(score)=sum(score)/count(score),如下:

override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)

这里我们自定义了一个MyAverage聚合函数用于计算score的平均值,如下:

package com.leboop.rdd

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * 用户自定义集成算子Demo
  */
object MyAverageTest {
  /**
    * 读取itemdata.data数据,计算平均score
    */
  class MyAverage extends UserDefinedAggregateFunction{
    /**
      * 计算平均score,输入的应该是score这一列数据
      * StructField定义了列字段的名称score_column,字段的类型Double
      * StructType要求输入数StructField构成的数组Array,这里只有一列,所以与Nil运算生成Array
      * @return StructType
      */
    override def inputSchema: StructType = StructType(
      StructField("score_column",DoubleType)::Nil)

    /**
      * 缓存Schema,存储中间计算结果,
      * 比如计算平均score,需要计算score的总和和score的个数,然后average(score)=sum(score)/count(score)
      * 所以这里定义了StructType类型:两个StructField字段:sum和count
      * @return StructType
      */
    override def bufferSchema: StructType = StructType(
      StructField("sum",DoubleType)::StructField("count",LongType)::Nil)

    /**
      * 自定义集成算子最终返回的数据类型
      * 也就是average(score)的类型,所以是Double
      * @return DataType 返回数据类型
      */
    override def dataType: DataType = DoubleType

    /**
      * 数据一致性检验:对于同样的输入,输出是一样的
      * @return Boolean true 同样的输入,输出也一样
      */
    override def deterministic: Boolean = true

    /**
      * 初始化缓存sum和count
      * sum=0.0,count=0
      * @param buffer 中间数据
      */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //sum=0.0
      buffer(0)=0.0
      //count=0
      buffer(1)=0L
    }

    /**
      * 每次计算更新缓存
      * @param buffer 缓存数据
      * @param input 输入数据score
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //输入非空
      if(!input.isNullAt(0)){
        //sum=sum+输入的score
        buffer(0)=buffer.getDouble(0)+input.getDouble(0)
        //count=count+1
        buffer(1)=buffer.getLong(1)+1
      }
    }

    /**
      * 将更新后的buffer存储到缓存
      * @param buffer1 缓存
      * @param buffer2 更新后的buffer
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)
      buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
    }

    /**
      * 计算最终的结果:average(score)=sum(score)/count(score)
      * @param buffer
      * @return
      */
    override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)

  }


  def main(args: Array[String]): Unit = {
    //创建Spark SQL切入点
    val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()
    //注册名为myAverage的自定义集成算子MyAverage
    spark.udf.register("myAverage",MyAverage)
    //读取HDFS文件系统数据itemdata.data转换成指定列名的DataFrame
    val dataDF=spark.read.csv("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data").toDF("user_id","item_id","score")
    //创建临时视图
    dataDF.createOrReplaceTempView("data")
    //通过sql计算平均工资
    spark.sql("select myAverage(score) as average_score from data").show()
  }
}

程序运行结果

+-----------------+
|    average_score|
+-----------------+
|3.257425742574257|
+-----------------+

 

二、类型安全的用户自定义聚合函数(Type-Safe User-Defined Aggregate Functions)

        实现类型安全的用户自定义聚合函数需要集成org.apache.spark.sql.expressions.Aggregator的Aggregator[K,V,C]抽象类,并且实现该类的6个函数。以上面计算score平均值的例子进行说明,并与无类型的用户自定于聚合函数对比。但是实现需要定义两个case class,如下:

  
case class Data(user_id: String, item_id: String, score: Double)
case class Average(var sum: Double,var count: Long)

Data用于存储itemdata.data数据,Average用于存储计算score平均值的中间数据,需要注意的是Average的参数sum和count都要声明为变量var。具体如下:

1、zero

        zero相当于1中的initialize初始化函数,初始化存储中间数据的Average,如下:

override def zero: Average = Average(0.0D, 0L)
 

2、reduce

        reduce函数相当于1中的update函数,当有新的数据a时,更新中间数据b,这里可使用+=复制(因为sum和count都是var),如下:

override def reduce(b: Average, a: Data): Average = {
      b.sum += a.score
      b.count += 1L
      b
    }

当然三行代码也可以直接写成Average(b.sum+a.score,b.count+1L),这样每次计算都会创建新的对象Average。

 

3、merge

        merge函数同1中的merge函数。如下:

override def merge(b1: Average, b2: Average): Average = {
      b1.sum+=b2.sum
      b1.count+= b2.count
      b1
    }

 

4、finish

        finish函数同1中的evaluate函数。计算最终的数据。如下:

override def finish(reduction: Average): Double = reduction.sum / reduction.count

 

5、bufferEncoder

        缓冲数据编码方式,如下:

override def bufferEncoder: Encoder[Average] = Encoders.product

 

6、outputEncoder

        最终数据输出编码方式,如下:

override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

整体代码如下:

package com.leboop.rdd

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator


/**
  * 类型安全自定义聚合函数
  */
object TypeSafeMyAverageTest {
  /**
    *Data类存储读取的文件数据
    */
  case class Data(user_id: String, item_id: String, score: Double)
  //Average
  case class Average(var sum: Double,var count: Long)

  object SafeMyAverage extends Aggregator[Data, Average, Double] {
    override def zero: Average = Average(0.0D, 0L)

    override def reduce(b: Average, a: Data): Average = {
      b.sum += a.score
      b.count += 1L
      b
    }

    override def merge(b1: Average, b2: Average): Average = {
      b1.sum+=b2.sum
      b1.count+= b2.count
      b1
    }

    override def finish(reduction: Average): Double = reduction.sum / reduction.count

    override def bufferEncoder: Encoder[Average] = Encoders.product

    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  def main(args: Array[String]): Unit = {
    //创建Spark SQL切入点
    val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()
    //读取HDFS文件系统数据itemdata.data生成RDD
    val rdd = spark.sparkContext.textFile("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data")
    //RDD转化成DataSet
    import spark.implicits._
    val dataDS =rdd.map(_.split(",")).map(d => Data(d(0), d(1), d(2).toDouble)).toDS()
    //自定义聚合函数
    val averageScore = SafeMyAverage.toColumn.name("average_score")
    dataDS.select(averageScore).show()
  }
}

程序执行结果如下:

+-----------------+
|    average_score|
+-----------------+
|3.257425742574257|
+-----------------+

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

你可能感兴趣的:(Spark,大数据)