用户自定义聚合函数(User Defined AGGregate function,UDAGG)会把一行或多行数据(也就是一个表)聚合成一个标量值。这是一个标准的“多对一”的转换。
聚合函数的概念我们之前已经接触过多次,如 SUM()、MAX()、MIN()、AVG()、COUNT()都是常见的系统内置聚合函数。而如果有些需求无法直接调用系统函数解决,我们就必须自定义聚合函数来实现功能了。
自定义聚合函数需要继承抽象类 AggregateFunction。AggregateFunction 有两个泛型参数
Flink SQL 中的聚合函数的工作原理如下:
(1)首先,它需要创建一个累加器(accumulator),用来存储聚合的中间结果。这与DataStream API 中的 AggregateFunction 非常类似,累加器就可以看作是一个聚合状态。调用createAccumulator()方法可以创建一个空的累加器。
(2)对于输入的每一行数据,都会调用 accumulate()方法来更新累加器,这是聚合的核心过程。
(3)当所有的数据都处理完之后,通过调用 getValue()方法来计算并返回最终的结果。
所以,每个 AggregateFunction 都必须实现以下几个方法:
1) createAccumulator()
这是创建累加器的方法。没有输入参数,返回类型为累加器类型 ACC。
2) accumulate()
这是进行聚合计算的核心方法,每来一行数据都会调用。它的第一个参数是确定的,就是当前的累加器,类型为 ACC,表示当前聚合的中间状态;后面的参数则是聚合函数调用时传入的参数,可以有多个,类型也可以不同。这个方法主要是更新聚合状态,所以没有返回类型。需要注意的是,accumulate()与之前的求值方法 eval()类似,也是底层架构要求的,必须为 public,方法名必须为 accumulate,且无法直接 override、只能手动实现。
3) getValue()
这是得到最终返回结果的方法。输入参数是 ACC 类型的累加器,输出类型为 T。
在遇到复杂类型时,Flink 的类型推导可能会无法得到正确的结果。所以AggregateFunction也可以专门对累加器和返回结果的类型进行声明,这是通过 getAccumulatorType()和getResultType()两个方法来指定的。
除了上面的方法,还有几个方法是可选的。这些方法有些可以让查询更加高效,有些是在某些特定场景下必须要实现的。比如,如果是对会话窗口进行聚合,merge()方法就是必须要实现的,它会定义累加器的合并操作,而且这个方法对一些场景的优化也很有用;而如果聚合函数用在 OVER 窗口聚合中,就必须实现 retract()方法,保证数据可以进行撤回操作;
resetAccumulator()方法则是重置累加器,这在一些批处理场景中会比较有用。AggregateFunction 的所有方法都必须是 公有的(public),不能是静态的(static),而且名字必须跟上面写的完全一样。 createAccumulator 、 getValue 、 getResultType 以 及getAccumulatorType 这几个方法是在抽象类 AggregateFunction 中定义的,可以 override;而其他则都是底层架构约定的方法。
下面举一个具体的示例。在常用的系统内置聚合函数里,可以用 AVG()来计算平均值;如果我们现在希望计算的是某个字段的“加权平均值”,所以只能自定义一个聚合函数 WeightedAvg 来计算了。
为了计算加权平均值,应该从输入的每行数据中提取两个值作为参数:要计算的分数值 iValue,以及它的权重iWeight。而在聚合过程中,累加器(accumulator)需要存储当前的加权总和 sum,以及目前数据的个数 count。这可以用一个二元组来表示,也可以单独定义一个类 WeightedAverage ,里面包含 sum 和 count 两个属性,用它的对象实例来作为聚合的累加器。
public class UdfTest_AggregateFunction {
public static void main(String[] args) throws Exception{
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
//1.在创建表的DDL中直接定义时间属性
String creatDDL = "CREATE TABLE clickTable (" +
"user_name STRING," +
"url STRING," +
"ts BIGINT," +
"et AS TO_TIMESTAMP( FROM_UNIXTIME(ts / 1000))," + //事件时间 FROM_UNIXTIME() 能转换为年月日时分秒这样的格式 转换秒
" WATERMARK FOR et AS et - INTERVAL '1' SECOND " + //watermark 延迟一秒
")WITH(" +
" 'connector' = 'filesystem'," +
" 'path' = 'input/clicks.txt'," +
" 'format' = 'csv'" +
")";
tableEnv.executeSql(creatDDL);
//2.注册自定义聚合函数
tableEnv.createTemporarySystemFunction("WeightedAvg", WeightedAverage.class);
//3.调用UDF进行查询转换
Table resultTable = tableEnv.sqlQuery("select user_name,WeightedAvg(ts,1) as w_avg from clickTable group by user_name");
//4.转换成流打印
tableEnv.toChangelogStream(resultTable).print();
env.execute();
}
//单独定义累加器类型
public static class WeightedAvgAccumulator{
public long sum = 0;
public int count = 0;
}
//实现自定义的聚合函数,计算平均时间戳
public static class WeightedAverage extends AggregateFunction<Long,WeightedAvgAccumulator>{
@Override
public Long getValue(WeightedAvgAccumulator accumulator) {
if (accumulator.count == 0){
return null;
}else{
return accumulator.sum / accumulator.count;
}
}
@Override
public WeightedAvgAccumulator createAccumulator() {
//初始化 累加器
return new WeightedAvgAccumulator();
}
//累加计算的方法 (第一个是 WeightedAvgAccum 类型的累加器;
//另外两个则是函数调用时输入的字段:要计算的值 iValue 和 对应的权重 iWeight。)
public void accumulate(WeightedAvgAccumulator accumulator,Long iValue,Integer iWeight){
accumulator.sum += iValue * iWeight;
accumulator.count += iWeight;
}
}
}
Gitee上的源代码
聚合函数的 accumulate()方法有三个输入参数。第一个是 WeightedAvgAccum 类型的累加器;另外两个则是函数调用时输入的字段:要计算的值 ivalue 和 对应的权重 iweight。