实现无类型的用户自定于聚合函数需要继承抽象类UserDefinedAggregateFunction,并重写该类的8个函数。我们以计算数据类型为Double的列score的平均值为例进行详细说明。score来源于数据文件itemdata.data,格式如下:
0162381440670851711,4,7.0
0162381440670851711,11,4.0
0162381440670851711,32,1.0
0162381440670851711,176,27.0
0162381440670851711,183,11.0
0162381440670851711,184,5.0
0162381440670851711,207,9.0
0162381440670851711,256,3.0
0162381440670851711,258,4.0
0162381440670851711,259,16.0
0162381440670851711,260,8.0
0162381440670851711,261,18.0
0162381440670851711,301,1.0
第一列为user_id,第二列为item_id,第三列为score。
定义输入数据的Schema,要求类型是StructType,它的参数是由StructField类型构成的数组。比如这里要定义score列的Schema,首先使用StructField声明score列的名字score_column,数据类型为DoubleType。这里输入只有score这一列,所以StructField构成的数组只有一个元素。如下:
override def inputSchema: StructType = StructType(StructField("score_column",DoubleType)::Nil)
::是Scala中的操作符与Nil空集合操作后生成一个数组。
事实上,计算score的平均值时,需要用到score的总和sum以及score的总个数count这样的中间数据,那么就使用bufferSchema来定义它们。如下:
override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil)
这里StructField类型的数组就有两个元素:数据类型为DoubleType的sum和数据类型为LongType类型的count。
我们需要对自定义聚合函数的最终数据类型进行说明,使用dataType函数。比如计算出的平均score是Double类型,如下定义:
override def dataType: DataType = DoubleType
deterministic函数用于对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出。因为对于同样的score输入,肯定要得到相同的score平均值,所以定义为true,如下:
override def deterministic: Boolean = true
initialize用户初始化缓存数据。比如score的缓存数据有两个:sum和count,需要初始化为sum=0.0和count=0L,第一个初始化为Double类型,第二个初始化为长整型。如下:
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//sum=0.0
buffer(0)=0.0
//count=0
buffer(1)=0L
}
当有新的输入数据时,update用户更新缓存变量。比如这里当有新的score输入时,需要将它的值更新变量sum中,并将count加1,如下:
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//输入非空
if(!input.isNullAt(0)){
//sum=sum+输入的score
buffer(0)=buffer.getDouble(0)+input.getDouble(0)
//count=count+1
buffer(1)=buffer.getLong(1)+1
}
}
merge将更新的缓存变量存入到缓存中。如下:
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)
buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
}
evaluate是一个计算方法,用于计算我们的最终结果。比如这里用于计算平均得分average(score)=sum(score)/count(score),如下:
override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)
这里我们自定义了一个MyAverage聚合函数用于计算score的平均值,如下:
package com.leboop.rdd
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
* 用户自定义集成算子Demo
*/
object MyAverageTest {
/**
* 读取itemdata.data数据,计算平均score
*/
class MyAverage extends UserDefinedAggregateFunction{
/**
* 计算平均score,输入的应该是score这一列数据
* StructField定义了列字段的名称score_column,字段的类型Double
* StructType要求输入数StructField构成的数组Array,这里只有一列,所以与Nil运算生成Array
* @return StructType
*/
override def inputSchema: StructType = StructType(
StructField("score_column",DoubleType)::Nil)
/**
* 缓存Schema,存储中间计算结果,
* 比如计算平均score,需要计算score的总和和score的个数,然后average(score)=sum(score)/count(score)
* 所以这里定义了StructType类型:两个StructField字段:sum和count
* @return StructType
*/
override def bufferSchema: StructType = StructType(
StructField("sum",DoubleType)::StructField("count",LongType)::Nil)
/**
* 自定义集成算子最终返回的数据类型
* 也就是average(score)的类型,所以是Double
* @return DataType 返回数据类型
*/
override def dataType: DataType = DoubleType
/**
* 数据一致性检验:对于同样的输入,输出是一样的
* @return Boolean true 同样的输入,输出也一样
*/
override def deterministic: Boolean = true
/**
* 初始化缓存sum和count
* sum=0.0,count=0
* @param buffer 中间数据
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//sum=0.0
buffer(0)=0.0
//count=0
buffer(1)=0L
}
/**
* 每次计算更新缓存
* @param buffer 缓存数据
* @param input 输入数据score
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//输入非空
if(!input.isNullAt(0)){
//sum=sum+输入的score
buffer(0)=buffer.getDouble(0)+input.getDouble(0)
//count=count+1
buffer(1)=buffer.getLong(1)+1
}
}
/**
* 将更新后的buffer存储到缓存
* @param buffer1 缓存
* @param buffer2 更新后的buffer
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)
buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
}
/**
* 计算最终的结果:average(score)=sum(score)/count(score)
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)
}
def main(args: Array[String]): Unit = {
//创建Spark SQL切入点
val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()
//注册名为myAverage的自定义集成算子MyAverage
spark.udf.register("myAverage",MyAverage)
//读取HDFS文件系统数据itemdata.data转换成指定列名的DataFrame
val dataDF=spark.read.csv("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data").toDF("user_id","item_id","score")
//创建临时视图
dataDF.createOrReplaceTempView("data")
//通过sql计算平均工资
spark.sql("select myAverage(score) as average_score from data").show()
}
}
程序运行结果
+-----------------+
| average_score|
+-----------------+
|3.257425742574257|
+-----------------+
实现类型安全的用户自定义聚合函数需要集成org.apache.spark.sql.expressions.Aggregator的Aggregator[K,V,C]抽象类,并且实现该类的6个函数。以上面计算score平均值的例子进行说明,并与无类型的用户自定于聚合函数对比。但是实现需要定义两个case class,如下:
case class Data(user_id: String, item_id: String, score: Double)
case class Average(var sum: Double,var count: Long)
Data用于存储itemdata.data数据,Average用于存储计算score平均值的中间数据,需要注意的是Average的参数sum和count都要声明为变量var。具体如下:
zero相当于1中的initialize初始化函数,初始化存储中间数据的Average,如下:
override def zero: Average = Average(0.0D, 0L)
reduce函数相当于1中的update函数,当有新的数据a时,更新中间数据b,这里可使用+=复制(因为sum和count都是var),如下:
override def reduce(b: Average, a: Data): Average = {
b.sum += a.score
b.count += 1L
b
}
当然三行代码也可以直接写成Average(b.sum+a.score,b.count+1L),这样每次计算都会创建新的对象Average。
merge函数同1中的merge函数。如下:
override def merge(b1: Average, b2: Average): Average = {
b1.sum+=b2.sum
b1.count+= b2.count
b1
}
finish函数同1中的evaluate函数。计算最终的数据。如下:
override def finish(reduction: Average): Double = reduction.sum / reduction.count
缓冲数据编码方式,如下:
override def bufferEncoder: Encoder[Average] = Encoders.product
最终数据输出编码方式,如下:
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
整体代码如下:
package com.leboop.rdd
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
/**
* 类型安全自定义聚合函数
*/
object TypeSafeMyAverageTest {
/**
*Data类存储读取的文件数据
*/
case class Data(user_id: String, item_id: String, score: Double)
//Average
case class Average(var sum: Double,var count: Long)
object SafeMyAverage extends Aggregator[Data, Average, Double] {
override def zero: Average = Average(0.0D, 0L)
override def reduce(b: Average, a: Data): Average = {
b.sum += a.score
b.count += 1L
b
}
override def merge(b1: Average, b2: Average): Average = {
b1.sum+=b2.sum
b1.count+= b2.count
b1
}
override def finish(reduction: Average): Double = reduction.sum / reduction.count
override def bufferEncoder: Encoder[Average] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
def main(args: Array[String]): Unit = {
//创建Spark SQL切入点
val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()
//读取HDFS文件系统数据itemdata.data生成RDD
val rdd = spark.sparkContext.textFile("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data")
//RDD转化成DataSet
import spark.implicits._
val dataDS =rdd.map(_.split(",")).map(d => Data(d(0), d(1), d(2).toDouble)).toDS()
//自定义聚合函数
val averageScore = SafeMyAverage.toColumn.name("average_score")
dataDS.select(averageScore).show()
}
}
程序执行结果如下:
+-----------------+
| average_score|
+-----------------+
|3.257425742574257|
+-----------------+