UDAF开发流程及心得

一、UDAF简介

先解释一下什么是UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。

 

关于UDAF的一个误区

我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:

1

select max(foo) from foobar group by bar;

表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

1

select max(foo) from foobar;

这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。

 

二、UDAF使用

2.1 继承UserDefinedAggregateFunction

使用UserDefinedAggregateFunction的套路:

1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现

2. 在spark中注册UDAF,为其绑定一个名字

3. 然后就可以在sql语句中使用上面绑定的名字调用

要点:个人觉得开发过程中比较难理解的是update和merge两个函数的理解(请查看代码),只有充分的了解之后才能完成UDAF函数的开发。

实现的功能:统计同一文件指定日期内,文件的增长大小以及文件的增长率,因UDAF函数返回类型是DataTypes中的类型,没有List类型,所有将文件增长和增长率用“,”分格,使用DataTypes.String返回,在函数的使用过程中用substirng_index进行截取

CAST( substring_index(aggregate(size,insertDB_time),',',1) as bigint) as deltaSize,

substring_index(aggregate(size,insertDB_time),',',-1) as percent

代码如下:

package com.zbj.finance.HDFSMonitor.utils;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;

public class SizePrecentUDAF extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public SizePrecentUDAF() {
        List inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("size", DataTypes.LongType, true));
        inputFields.add(DataTypes.createStructField("create_time", DataTypes.TimestampType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("firstSize", DataTypes.LongType, true));
        bufferFields.add(DataTypes.createStructField("endSize", DataTypes.LongType, true));
        bufferFields.add(DataTypes.createStructField("min_time", DataTypes.TimestampType, true));
        bufferFields.add(DataTypes.createStructField("max_time", DataTypes.TimestampType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.StringType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L);
        buffer.update(1, 0L);
        buffer.update(2, null);
        buffer.update(3, null);
    }

    /**
     * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
     * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
     * 大聚和发生在reduce端.
     * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
     * update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (buffer.getTimestamp(2) != null && input.getTimestamp(1).before(buffer.getTimestamp(2))) {
            buffer.update(0, input.getLong(0));
            buffer.update(2, input.getTimestamp(1));
        }
        if (buffer.getTimestamp(3) != null && input.getTimestamp(1).after(buffer.getTimestamp(3))) {
            buffer.update(1, input.getLong(0));
            buffer.update(3, input.getTimestamp(1));
        }
        if (buffer.getTimestamp(2) == null) {
            buffer.update(0, input.getLong(0));
            buffer.update(1, input.getLong(0));
            buffer.update(2, input.getTimestamp(1));
            buffer.update(3, input.getTimestamp(1));
        }
    }

    /**
     * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
     * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
     * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
     * 也可以是一个节点里面的多个executor合并 reduce端大聚合
     * merge后的结果写如buffer1中
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (buffer1.getTimestamp(2) == null) {
            buffer1.update(0, buffer2.getLong(0));
            buffer1.update(1, buffer2.getLong(1));
            buffer1.update(2, buffer2.getTimestamp(2));
            buffer1.update(3, buffer2.getTimestamp(3));
        }
        if (buffer2.getTimestamp(2).before(buffer1.getTimestamp(2))) {
            buffer1.update(0, buffer2.getLong(0));
            buffer1.update(2, buffer2.getTimestamp(2));
        }
        if (buffer2.getTimestamp(3).after(buffer1.getTimestamp(3))) {
            buffer1.update(1, buffer2.getLong(1));
            buffer1.update(3, buffer2.getTimestamp(3));
        }
    }

    @Override
    public Object evaluate(Row buffer) {
        DecimalFormat df = new DecimalFormat("#.##");
        long firstSize = buffer.getLong(0);
        long minusSize = buffer.getLong(1) - buffer.getLong(0);
        Double value = 0.0D;
        if (minusSize == 0L) {
            value = 0.0D;
        } else if (firstSize != 0L) {
            double d = (double) (minusSize) / (double) firstSize;
            value = d * 100;
        } else {
            value = Double.MAX_VALUE;
        }
        String percentValue = df.format(value);
        value = Double.parseDouble(percentValue);
        return String.valueOf(minusSize) + "," + value;
    }
}

 

 

你可能感兴趣的:(SparkSql)