Flink TableAPI和SQL(二十)UDF(三)聚合函数(Aggregate Functions)

文章目录

用户自定义聚合函数(User Defined AGGregate function,UDAGG)会把一行或多行数据(也就是一个表)聚合成一个标量值。这是一个标准的“多对一”的转换。

聚合函数的概念我们之前已经接触过多次,如 SUM()、MAX()、MIN()、AVG()、COUNT()都是常见的系统内置聚合函数。而如果有些需求无法直接调用系统函数解决,我们就必须自定义聚合函数来实现功能了。

自定义聚合函数需要继承抽象类 AggregateFunction。AggregateFunction 有两个泛型参数,T 表示聚合输出的结果类型,ACC 则表示聚合的中间状态类型。

Flink SQL 中的聚合函数的工作原理如下:
(1)首先,它需要创建一个累加器(accumulator),用来存储聚合的中间结果。这与DataStream API 中的 AggregateFunction 非常类似,累加器就可以看作是一个聚合状态。调用createAccumulator()方法可以创建一个空的累加器。
(2)对于输入的每一行数据,都会调用 accumulate()方法来更新累加器,这是聚合的核心过程。
(3)当所有的数据都处理完之后,通过调用 getValue()方法来计算并返回最终的结果。

所以,每个 AggregateFunction 都必须实现以下几个方法:

1) createAccumulator()
这是创建累加器的方法。没有输入参数,返回类型为累加器类型 ACC。

2) accumulate()
这是进行聚合计算的核心方法,每来一行数据都会调用。它的第一个参数是确定的,就是当前的累加器,类型为 ACC,表示当前聚合的中间状态;后面的参数则是聚合函数调用时传入的参数,可以有多个,类型也可以不同。这个方法主要是更新聚合状态,所以没有返回类型。需要注意的是,accumulate()与之前的求值方法 eval()类似,也是底层架构要求的,必须为 public,方法名必须为 accumulate,且无法直接 override、只能手动实现。

3) getValue()
这是得到最终返回结果的方法。输入参数是 ACC 类型的累加器,输出类型为 T。
在遇到复杂类型时,Flink 的类型推导可能会无法得到正确的结果。所以AggregateFunction也可以专门对累加器和返回结果的类型进行声明,这是通过 getAccumulatorType()和getResultType()两个方法来指定的。
除了上面的方法,还有几个方法是可选的。这些方法有些可以让查询更加高效,有些是在某些特定场景下必须要实现的。比如,如果是对会话窗口进行聚合,merge()方法就是必须要实现的,它会定义累加器的合并操作,而且这个方法对一些场景的优化也很有用;而如果聚合函数用在 OVER 窗口聚合中,就必须实现 retract()方法,保证数据可以进行撤回操作;

resetAccumulator()方法则是重置累加器,这在一些批处理场景中会比较有用。AggregateFunction 的所有方法都必须是 公有的(public),不能是静态的(static),而且名字必须跟上面写的完全一样。 createAccumulator 、 getValue 、 getResultType 以 及getAccumulatorType 这几个方法是在抽象类 AggregateFunction 中定义的,可以 override;而其他则都是底层架构约定的方法。

下面举一个具体的示例。在常用的系统内置聚合函数里,可以用 AVG()来计算平均值;如果我们现在希望计算的是某个字段的“加权平均值”,所以只能自定义一个聚合函数 WeightedAvg 来计算了。
为了计算加权平均值,应该从输入的每行数据中提取两个值作为参数:要计算的分数值 iValue,以及它的权重iWeight。而在聚合过程中,累加器(accumulator)需要存储当前的加权总和 sum,以及目前数据的个数 count。这可以用一个二元组来表示,也可以单独定义一个类 WeightedAverage ,里面包含 sum 和 count 两个属性,用它的对象实例来作为聚合的累加器。

public class UdfTest_AggregateFunction {
    public static void main(String[] args) throws Exception{
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);

        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);

        //1.在创建表的DDL中直接定义时间属性
        String creatDDL = "CREATE TABLE clickTable (" +
                "user_name STRING," +
                "url STRING," +
                "ts BIGINT," +
                "et AS TO_TIMESTAMP( FROM_UNIXTIME(ts / 1000))," + //事件时间  FROM_UNIXTIME() 能转换为年月日时分秒这样的格式 转换秒
                " WATERMARK FOR et AS et - INTERVAL '1' SECOND " + //watermark 延迟一秒
                ")WITH(" +
                " 'connector' = 'filesystem'," +
                " 'path' = 'input/clicks.txt'," +
                " 'format' = 'csv'" +
                ")";

        tableEnv.executeSql(creatDDL);

        //2.注册自定义聚合函数
        tableEnv.createTemporarySystemFunction("WeightedAvg", WeightedAverage.class);

        //3.调用UDF进行查询转换
        Table resultTable = tableEnv.sqlQuery("select user_name,WeightedAvg(ts,1) as w_avg from clickTable group by user_name");

        //4.转换成流打印
        tableEnv.toChangelogStream(resultTable).print();


        env.execute();
    }

    //单独定义累加器类型
    public static class WeightedAvgAccumulator{
        public long sum = 0;
        public int count = 0;
    }

    //实现自定义的聚合函数,计算平均时间戳
    public static class WeightedAverage extends AggregateFunction<Long,WeightedAvgAccumulator>{

        @Override
        public Long getValue(WeightedAvgAccumulator accumulator) {
            if (accumulator.count == 0){
                return null;
            }else{
                return accumulator.sum / accumulator.count;
            }

        }

        @Override
        public WeightedAvgAccumulator createAccumulator() {
            //初始化 累加器
            return new WeightedAvgAccumulator();
        }

        //累加计算的方法 (第一个是 WeightedAvgAccum 类型的累加器;
        //另外两个则是函数调用时输入的字段:要计算的值 iValue 和 对应的权重 iWeight。)
        public void accumulate(WeightedAvgAccumulator accumulator,Long iValue,Integer iWeight){
            accumulator.sum += iValue * iWeight;
            accumulator.count += iWeight;
        }
    }
}

Gitee上的源代码

聚合函数的 accumulate()方法有三个输入参数。第一个是 WeightedAvgAccum 类型的累加器;另外两个则是函数调用时输入的字段:要计算的值 ivalue 和 对应的权重 iweight。

你可能感兴趣的:(#,Flink,sql,flink)