在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
- UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
- UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg
- UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap
一、自定义UDF 拼接三个参数,
1.1继承org.apache.spark.sql.api.java.UDFxx(1-22);
1.2、实现call方法
@Override
public String call(Long v1, String v2, String split) throws Exception {
return String.valueOf(v1) + split + v2;
}
完整代码实现
package com.chb.shopanalysis.hive.UDF;
import org.apache.spark.sql.api.java.UDF3;
public class ConcatLongStringUDF implements UDF3<Long, String, String, String> {
private static final long serialVersionUID = 1L;
@Override
public String call(Long v1, String v2, String split) throws Exception {
return String.valueOf(v1) + split + v2;
}
}
1.4、注册函数
sqlContext.udf().register(
"concat_long_string",
new ConcatLongStringUDF(),
DataTypes.StringType);
1.5、使用函数
/**
* 从hive表中读取数据, 使用自定义聚合函数
*/
private static void readProductClickInfo() {
String sql =
"SELECT city_id, city_name,"
+ "area,"
+ "product_id,"
+ "concat_long_string(city_id,city_name,':') city_infos "
+ "FROM click_product_basic ";
DataFrame df = sqlContext.sql(sql);
df.show();
}
二、用户自定义聚合函数UDAF
2.1、继承org.apache.spark.sql.expressions.UserDefinedAggregateFunction
2.2、定义输入,缓存,输出字段类型
private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));
private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)));
private DataType dataType = DataTypes.StringType;
2.3、deterministic()
决定每次相同输入,是否返回相同输出, 一般都会设置为true.
@Override
public boolean deterministic() {
return deterministic;
}
2.4、初始化
/**
* 初始化
* 可以认为是,你自己在内部指定一个初始的值
*/
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, "");
}
2.5、更新, 这个是组类根据自己的逻辑进行拼接, 然后更新数据
/**
* 更新
* 可以认为是,一个一个地将组内的字段值传递进来
* 实现拼接的逻辑
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
String bufferCityInfo = buffer.getString(0);
String cityInfo = input.getString(0);
if(!bufferCityInfo.contains(cityInfo)) {
if("".equals(bufferCityInfo)) {
bufferCityInfo += cityInfo;
} else {
bufferCityInfo += "," + cityInfo;
}
buffer.update(0, bufferCityInfo);
}
}
2.6、合并, 将所有节点的数据进行合并
/**
* 合并
* update操作,可能是针对一个分组内的部分数据,在某个节点上发生的
* 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
String bufferCityInfo1 = buffer1.getString(0);
String bufferCityInfo2 = buffer2.getString(0);
for(String cityInfo : bufferCityInfo2.split(",")) {
if(!bufferCityInfo1.contains(cityInfo)) {
if("".equals(bufferCityInfo1)) {
bufferCityInfo1 += cityInfo;
} else {
bufferCityInfo1 += "," + cityInfo;
}
}
}
buffer1.update(0, bufferCityInfo1);
}
2.7、输出最终结果, 可能我们需要的输出格式,可以在该方法中,进行格式化。
@Override
public Object evaluate(Row row) {
return row.getString(0);
}
2.8、注册函数
sqlContext.udf().register("group_concat_distinct",
new GroupConcatDistinctUDAF())
2.9、使用
/**
* 从hive表中读取数据, 使用自定义聚合函数
*/
private static void readProductClickInfo() {
String sql = "SELECT area, product_id,"
+ "count(*) click_count, "
+ "group_concat_distinct(concat_long_string(city_id,city_name,':')) city_infos "
+ "FROM click_product_basic "
+ "GROUP BY area,product_id ";
DataFrame df = sqlContext.sql(sql);
df.show();
df.registerTempTable("tmp_area_product_click_count");
}