Spark SQL 中 UDF 和 UDAF 的使用

Spark SQL 支持 Hive 的 UDF(User defined functions) 和 UDAF(User defined aggregation functions)

UDF 传入参数只能是表中的 1 行数据(可以是多列字段),传出参数也是 1 行,具体使用如下:

/**
 * 拼接一行中两列字段,数据类型一个为长整型,一个为字符串
 * Created by zhulei on 2017/6/20.
 */
public class ConcatLongStringUDF implements UDF3<Long, String, String, String> {
    @Override
    public String call(Long v1, String v2, String split) throws Exception {
        return String.valueOf(v1) + split + v2;
    }
}

//然后在 main 方法中注册
sqlContext.udf().register("concat_long_string", new ConcatLongStringUDF(), DataTypes.StringType);

UDAF 传入参数是多行的数据,然后通过聚合运算输出一行数据,具体使用如下:

/**
 * 

* 组内拼接去重函数 * 多行输入,聚合成一行输出 * Created by zhulei on 2017/6/20. */ public class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction { /** * 定义输入数据的 schema * 比如你要将多行多列的数据合并,可以理解成输入多行多列的数据所对应的 schema * 这里输入的只有一列数据,所以 schema 也就只有一个字段 */ @Override public StructType inputSchema() { return DataTypes.createStructType(Collections.singletonList( DataTypes.createStructField("cityInfo", DataTypes.StringType, true))); } /** * 定义用来存储中间计算结果的 buffer 对应的 schema * 这个值是根据你的计算过程来定的 */ @Override public StructType bufferSchema() { return DataTypes.createStructType(Collections.singletonList( DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true) )); } /** * 输出值的数据类型 */ @Override public DataType dataType() { return DataTypes.StringType; } /** * 输入值和输出值是不是确定的 */ @Override public boolean deterministic() { return true; } /** * 初始化中间计算结果变量 */ @Override public void initialize(MutableAggregationBuffer buffer) { buffer.update(0, ""); } /** * 更新计算结果 * 不断的将每个输入值通过你的计算方法去计算 */ @Override public void update(MutableAggregationBuffer buffer, Row input) { String bufferCityIno = buffer.getString(0); String inputCityInfo = input.getString(0); if (!bufferCityIno.contains(inputCityInfo)) { if ("".equals(bufferCityIno)) { bufferCityIno += inputCityInfo; } else { bufferCityIno += "," + inputCityInfo; } buffer.update(0, bufferCityIno); } } /** * update 操作是某个节点上的计算 * merge 是将多个节点的结果进行合并 */ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { String aggBuffer1 = buffer1.getString(0); String aggBuffer2 = buffer2.getString(0); for (String ele : aggBuffer2.split(",")) { if (!aggBuffer1.contains(ele)) { if ("".equals(aggBuffer1)) { aggBuffer1 += ele; } else { aggBuffer1 += "," + ele; } } } buffer1.update(0, aggBuffer1); } /** * 输出最终计算结果 */ @Override public Object evaluate(Row buffer) { return buffer.getString(0); } } //然后在 main 方法中注册 sqlContext.udf().register("group_concat_distinct", new GroupConcatDistinctUDAF());

你可能感兴趣的:(BigData)