UDF----------------------------------------
完整的示例:
object SparkSQL {
------------------------
def main(args:Array[String]):Unit = {
//创建SparkConf()并设置App名称
val conf = new SparkConf().setAppName("SparkSQLDemo").setMaster("local")
val spark = SparkSession.builder().config(conf).getOrCreate()
val df: DataFrame = spark.read.json("dir/people.json")
//注册函数,在整个应用中可以使用
val addName = spark.udf.register("addName", (x: String) => "Name:" + x)
df.createOrReplaceTempView("people")
spark.sql("Select addName(name), age from people").show()
spark.stop()
}
}
/**
* 自定义UDF函数
* 传入一个json形式的字符串,获取指定字段,返回改字段的值
*/
public class GetJsonObjectUDF implements UDF2 {
private static final long serialVersionUID = 6776121915573178083L;
@Override
public String call(String json, String field) throws Exception {
//fastjson自带的方法
try {
JSONObject jsonObject = JSONObject.parseObject(json);
return jsonObject.getString(field);
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
}
----------------
/**
* 将两个字段拼接起来(使用指定的分隔符)
*/
public class ConcatLongStringUDF implements UDF3 {
@Override
public String call(Long v1, String v2, String split) throws Exception {
return String.valueOf(v1) + split + v2;
}
}
-------------------
/**
* random_prefix() 添加随机数前缀
* @author Administrator
*
*/
public class RandomPrefixUDF implements UDF2 {
private static final long serialVersionUID = 1L;
@Override
public String call(String val, Integer num) throws Exception {
Random random = new Random();
int randNum = random.nextInt(10);
return randNum + "_" + val;
}
}
-----------------
**
* 去除随机前缀
* @author Administrator
*
*/
public class RemoveRandomPrefixUDF implements UDF1 {
private static final long serialVersionUID = 1L;
@Override
public String call(String val) throws Exception {
String[] valSplited = val.split("_");
return valSplited[1];
}
}
UDAF----------------------------------------
把城市信息拼接起来,并且去重。结果数据: <1:北京,2:贵港。。。>
public class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {
private static final long serialVersionUID = 8177502627955458491L;
// 指定输入数据的字段与类型
private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
DataTypes.createStructField("cityInfo",DataTypes.StringType,true)));
// 指定缓冲数据的字段与类型
private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
DataTypes.createStructField("bufferCityInfo",DataTypes.StringType,true)));
// 指定返回类型
private DataType dataType = DataTypes.StringType;
// 指定是否是确定性的
private boolean deterministic = true;
// 指定输入数据的字段与类型
@Override
public StructType inputSchema() {
return inputSchema;
}
// 指定缓冲数据的字段与类型
@Override
public StructType bufferSchema() {
return bufferSchema;
}
// 指定返回类型
@Override
public DataType dataType() {
return dataType;
}
// 指定是否是确定性的
@Override
public boolean deterministic() {
return deterministic;
}
//初始化
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0,"");
}
/**
* 更新
* 可以认为是,一个一个地将组内的字段值传递进来
* 实现拼接的逻辑
* @param buffer
* @param input
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
// 缓冲中的已经拼接过的城市信息串
String bufferCityInfo = buffer.getString(0);
// 刚刚传递进来的某个城市信息
String cityInfo = input.getString(0);
// 在这里要实现去重的逻辑
if(!bufferCityInfo.contains(cityInfo)){
if("".equals(bufferCityInfo)){
bufferCityInfo += cityInfo;
}else{
bufferCityInfo += "," + cityInfo;
}
buffer.update(0,bufferCityInfo); //
}
}
/**
* 合并
* update操作,可能是针对一个分组内的部分数据,在某个节点上发生的
* 但是可能一个分组内的数据,会分布在多个节点上处理
* 用merge操作,将各个节点上分布式拼接好的串,合并起来
* @param buffer1
* @param buffer2
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
String bufferCityInfo1 = buffer1.getString(0);
String bufferCityInfo2 = buffer2.getString(0);
for(String cityInfo : bufferCityInfo2.split(",")){
if(!bufferCityInfo1.contains(cityInfo)){
if("".equals(bufferCityInfo1)){
bufferCityInfo1 += cityInfo;
}else {
bufferCityInfo1 += "," + cityInfo;
}
}
}
buffer1.update(0,bufferCityInfo1);
}
@Override
public Object evaluate(Row row) {
return row.getString(0);
}
}
UDTF----------------------------------------
//传入一个字符串,返回多个数组,每个数组只有一个值;跟hive的udtf很像
class UserDefinedUDTF extends GenericUDTF{
//这个方法的作用:1.输入参数校验 2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
if (args.length != 1) {
throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
}
if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
}
val fieldNames = new util.ArrayList[String]
val fieldOIs = new util.ArrayList[ObjectInspector]
//这里定义的是输出列默认字段名称
fieldNames.add("col1")
//这里定义的是输出列字段类型
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
}
//这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
override def process(args: Array[AnyRef]): Unit = {
//将字符串切分成单个字符的数组
val strLst = args(0).toString.split("")
for(i <- strLst){
var tmp:Array[String] = new Array[String](1)
tmp(0) = i
//调用forward方法,必须传字符串数组,即使只有一个元素
forward(tmp)
}
}
override def close(): Unit = {}
}