先解释一下什么是UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。
我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:
1 |
|
表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:
1 |
|
这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。
使用UserDefinedAggregateFunction的套路:
1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现
2. 在spark中注册UDAF,为其绑定一个名字
3. 然后就可以在sql语句中使用上面绑定的名字调用
要点:个人觉得开发过程中比较难理解的是update和merge两个函数的理解(请查看代码),只有充分的了解之后才能完成UDAF函数的开发。
实现的功能:统计同一文件指定日期内,文件的增长大小以及文件的增长率,因UDAF函数返回类型是DataTypes中的类型,没有List类型,所有将文件增长和增长率用“,”分格,使用DataTypes.String返回,在函数的使用过程中用substirng_index进行截取
CAST( substring_index(aggregate(size,insertDB_time),',',1) as bigint) as deltaSize,
substring_index(aggregate(size,insertDB_time),',',-1) as percent
代码如下:
package com.zbj.finance.HDFSMonitor.utils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
public class SizePrecentUDAF extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public SizePrecentUDAF() {
List inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("size", DataTypes.LongType, true));
inputFields.add(DataTypes.createStructField("create_time", DataTypes.TimestampType, true));
inputSchema = DataTypes.createStructType(inputFields);
List bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("firstSize", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("endSize", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("min_time", DataTypes.TimestampType, true));
bufferFields.add(DataTypes.createStructField("max_time", DataTypes.TimestampType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
@Override
public StructType inputSchema() {
return inputSchema;
}
@Override
public StructType bufferSchema() {
return bufferSchema;
}
@Override
public DataType dataType() {
return DataTypes.StringType;
}
@Override
public boolean deterministic() {
return true;
}
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0L);
buffer.update(1, 0L);
buffer.update(2, null);
buffer.update(3, null);
}
/**
* 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
* 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
* 大聚和发生在reduce端.
* 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
* update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
if (buffer.getTimestamp(2) != null && input.getTimestamp(1).before(buffer.getTimestamp(2))) {
buffer.update(0, input.getLong(0));
buffer.update(2, input.getTimestamp(1));
}
if (buffer.getTimestamp(3) != null && input.getTimestamp(1).after(buffer.getTimestamp(3))) {
buffer.update(1, input.getLong(0));
buffer.update(3, input.getTimestamp(1));
}
if (buffer.getTimestamp(2) == null) {
buffer.update(0, input.getLong(0));
buffer.update(1, input.getLong(0));
buffer.update(2, input.getTimestamp(1));
buffer.update(3, input.getTimestamp(1));
}
}
/**
* 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
* 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
* 也可以是一个节点里面的多个executor合并 reduce端大聚合
* merge后的结果写如buffer1中
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
if (buffer1.getTimestamp(2) == null) {
buffer1.update(0, buffer2.getLong(0));
buffer1.update(1, buffer2.getLong(1));
buffer1.update(2, buffer2.getTimestamp(2));
buffer1.update(3, buffer2.getTimestamp(3));
}
if (buffer2.getTimestamp(2).before(buffer1.getTimestamp(2))) {
buffer1.update(0, buffer2.getLong(0));
buffer1.update(2, buffer2.getTimestamp(2));
}
if (buffer2.getTimestamp(3).after(buffer1.getTimestamp(3))) {
buffer1.update(1, buffer2.getLong(1));
buffer1.update(3, buffer2.getTimestamp(3));
}
}
@Override
public Object evaluate(Row buffer) {
DecimalFormat df = new DecimalFormat("#.##");
long firstSize = buffer.getLong(0);
long minusSize = buffer.getLong(1) - buffer.getLong(0);
Double value = 0.0D;
if (minusSize == 0L) {
value = 0.0D;
} else if (firstSize != 0L) {
double d = (double) (minusSize) / (double) firstSize;
value = d * 100;
} else {
value = Double.MAX_VALUE;
}
String percentValue = df.format(value);
value = Double.parseDouble(percentValue);
return String.valueOf(minusSize) + "," + value;
}
}