Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF

在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:

  • UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
  • UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg
  • UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap

一、自定义UDF 拼接三个参数,

1.1继承org.apache.spark.sql.api.java.UDFxx(1-22);

1.2、实现call方法

    @Override
    public String call(Long v1, String v2, String split) throws Exception {
        return String.valueOf(v1) + split + v2;
    }

完整代码实现

package com.chb.shopanalysis.hive.UDF;

import org.apache.spark.sql.api.java.UDF3;

/**
 * 自定义UDF
 * 1 上海  split
 * 拼接成"1:上海"
 * 将两个字段拼接起来(使用指定的分隔符)
 * @author chb
 *
 */
public class ConcatLongStringUDF implements UDF3<Long, String, String, String> {

    private static final long serialVersionUID = 1L;

    @Override
    public String call(Long v1, String v2, String split) throws Exception {
        return String.valueOf(v1) + split + v2;
    }

}

1.4、注册函数

        // 注册自定义函数
        sqlContext.udf().register(
                "concat_long_string",       //自定义函数的名称
                new ConcatLongStringUDF(),  //自定义UDF对象
                DataTypes.StringType);      //返回数据类型

1.5、使用函数

    /**
     * 从hive表中读取数据, 使用自定义聚合函数
     */
    private static void readProductClickInfo() {

        // 可以获取到每个area下的每个product_id的城市信息拼接起来的串

        String sql = 
                "SELECT city_id, city_name,"
                    + "area,"
                    + "product_id,"
                    + "concat_long_string(city_id,city_name,':') city_infos "  
                + "FROM click_product_basic ";



        // 使用Spark SQL执行这条SQL语句
        DataFrame df = sqlContext.sql(sql);
        //展示结果
        df.show();

    }

Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF_第1张图片

二、用户自定义聚合函数UDAF

2.1、继承org.apache.spark.sql.expressions.UserDefinedAggregateFunction

2.2、定义输入,缓存,输出字段类型

    // 指定输入数据的字段与类型
    private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
            DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));  
    // 指定缓冲数据的字段与类型
    private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
            DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)));  
    // 指定返回类型
    private DataType dataType = DataTypes.StringType;

2.3、deterministic()决定每次相同输入,是否返回相同输出, 一般都会设置为true.

    @Override
    //每次相同的输入是否返回相同的输出
    public boolean deterministic() {
        return deterministic;
    }

2.4、初始化

    /**
     * 初始化
     * 可以认为是,你自己在内部指定一个初始的值
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, "");  
    }

2.5、更新, 这个是组类根据自己的逻辑进行拼接, 然后更新数据

    /**
     * 更新
     * 可以认为是,一个一个地将组内的字段值传递进来
     * 实现拼接的逻辑
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // 缓冲中的已经拼接过的城市信息串
        String bufferCityInfo = buffer.getString(0);
        // 刚刚传递进来的某个城市信息
        String cityInfo = input.getString(0);

        // 在这里要实现去重的逻辑
        // 判断:之前没有拼接过某个城市信息,那么这里才可以接下去拼接新的城市信息
        if(!bufferCityInfo.contains(cityInfo)) {
            if("".equals(bufferCityInfo)) {
                bufferCityInfo += cityInfo;
            } else {
                // 比如1:北京
                //2:上海
                //结果 1:北京,2:上海
                //再 来一个 1:北京  就不会拼接进去。
                bufferCityInfo += "," + cityInfo;
            }

            buffer.update(0, bufferCityInfo);  
        }
    }

2.6、合并, 将所有节点的数据进行合并

    /**
     * 合并
     * update操作,可能是针对一个分组内的部分数据,在某个节点上发生的
     * 但是可能一个分组内的数据,会分布在多个节点上处理
     * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        String bufferCityInfo1 = buffer1.getString(0);
        String bufferCityInfo2 = buffer2.getString(0);

        for(String cityInfo : bufferCityInfo2.split(",")) {
            if(!bufferCityInfo1.contains(cityInfo)) {
                if("".equals(bufferCityInfo1)) {
                    bufferCityInfo1 += cityInfo;
                } else {
                    bufferCityInfo1 += "," + cityInfo;
                }
            }
        }
        buffer1.update(0, bufferCityInfo1);  
    }

2.7、输出最终结果, 可能我们需要的输出格式,可以在该方法中,进行格式化。

        @Override
        //计算出最终结果
        public Object evaluate(Row row) {  
            return row.getString(0);  
        }

2.8、注册函数

        sqlContext.udf().register("group_concat_distinct", 
                new GroupConcatDistinctUDAF());

2.9、使用

    /**
     * 从hive表中读取数据, 使用自定义聚合函数
     */
    private static void readProductClickInfo() {
        // 按照area和product_id两个字段进行分组
        // 计算出各区域各商品的点击次数
        // 可以获取到每个area下的每个product_id的城市信息拼接起来的串

        String sql =  "SELECT  area, product_id,"
                + "count(*) click_count, "  
                + "group_concat_distinct(concat_long_string(city_id,city_name,':')) city_infos "  
                + "FROM click_product_basic "
                + "GROUP BY area,product_id "; 

        // 使用Spark SQL执行这条SQL语句
        DataFrame df = sqlContext.sql(sql);

        df.show();
        // 再次将查询出来的数据注册为一个临时表
        // 各区域各商品的点击次数(以及额外的城市列表)
        df.registerTempTable("tmp_area_product_click_count");    
    }

你可能感兴趣的:(#,spark)