SparkSql之UDF、UDAF、UDTF

UDF----------------------------------------

完整的示例:object SparkSQL {
  def main(args:Array[String]):Unit = {
  //创建SparkConf()并设置App名称
    val conf = new SparkConf().setAppName("SparkSQLDemo").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    val df: DataFrame = spark.read.json("dir/people.json")
    //注册函数,在整个应用中可以使用
    val addName = spark.udf.register("addName", (x: String) => "Name:" + x)
    df.createOrReplaceTempView("people")
    spark.sql("Select addName(name), age from people").show()
    spark.stop()
  }
}
------------------------

/**
 * 自定义UDF函数
 * 传入一个json形式的字符串,获取指定字段,返回改字段的值
 */
public class GetJsonObjectUDF implements UDF2 {

    private static final long serialVersionUID = 6776121915573178083L;

    @Override
    public String call(String json, String field) throws Exception {

        //fastjson自带的方法
        try {
            JSONObject jsonObject = JSONObject.parseObject(json);
            return jsonObject.getString(field);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}

----------------

/**
 * 将两个字段拼接起来(使用指定的分隔符)
 */
public class ConcatLongStringUDF implements UDF3 {

    @Override
    public String call(Long v1, String v2, String split) throws Exception {
        return String.valueOf(v1) + split + v2;
    }
}

-------------------

/**
 * random_prefix()   添加随机数前缀
 * @author Administrator
 *
 */
public class RandomPrefixUDF implements UDF2 {

   private static final long serialVersionUID = 1L;

   @Override
   public String call(String val, Integer num) throws Exception {
      Random random = new Random();
      int randNum = random.nextInt(10);
      return randNum + "_" + val;
   }
   
}

-----------------

**
 * 去除随机前缀
 * @author Administrator
 *
 */
public class RemoveRandomPrefixUDF implements UDF1 {

   private static final long serialVersionUID = 1L;

   @Override
   public String call(String val) throws Exception {
      String[] valSplited = val.split("_");
      return valSplited[1];
   }

}
UDAF----------------------------------------
把城市信息拼接起来,并且去重。结果数据: <1:北京,2:贵港。。。>

public class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {

    private static final long serialVersionUID = 8177502627955458491L;

    // 指定输入数据的字段与类型
   private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
          DataTypes.createStructField("cityInfo",DataTypes.StringType,true)));

    // 指定缓冲数据的字段与类型
    private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
          DataTypes.createStructField("bufferCityInfo",DataTypes.StringType,true)));

    // 指定返回类型
    private DataType dataType = DataTypes.StringType;

    // 指定是否是确定性的
    private boolean deterministic = true;

    // 指定输入数据的字段与类型
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    // 指定缓冲数据的字段与类型
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    // 指定返回类型
    @Override
    public DataType dataType() {
        return dataType;
    }

    // 指定是否是确定性的
    @Override
    public boolean deterministic() {
        return deterministic;
    }

    //初始化
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0,"");
    }

    /**
     * 更新
     * 可以认为是,一个一个地将组内的字段值传递进来
     * 实现拼接的逻辑
     * @param buffer
     * @param input
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // 缓冲中的已经拼接过的城市信息串
        String bufferCityInfo = buffer.getString(0);

        // 刚刚传递进来的某个城市信息
        String cityInfo = input.getString(0);

        // 在这里要实现去重的逻辑
        if(!bufferCityInfo.contains(cityInfo)){
            if("".equals(bufferCityInfo)){
                bufferCityInfo += cityInfo;
            }else{
                bufferCityInfo += "," + cityInfo;
            }
            buffer.update(0,bufferCityInfo);    //
        }
    }

    /**
     * 合并
     * update操作,可能是针对一个分组内的部分数据,在某个节点上发生的
     *     但是可能一个分组内的数据,会分布在多个节点上处理
     *     用merge操作,将各个节点上分布式拼接好的串,合并起来
     * @param buffer1
     * @param buffer2
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        String bufferCityInfo1 = buffer1.getString(0);
        String bufferCityInfo2 = buffer2.getString(0);

        for(String cityInfo : bufferCityInfo2.split(",")){
            if(!bufferCityInfo1.contains(cityInfo)){
                if("".equals(bufferCityInfo1)){
                    bufferCityInfo1 += cityInfo;
                }else {
                    bufferCityInfo1 += "," + cityInfo;
                }
            }
        }
        buffer1.update(0,bufferCityInfo1);
    }

    @Override
    public Object evaluate(Row row) {
        return row.getString(0);
    }
}

 

UDTF----------------------------------------

//传入一个字符串,返回多个数组,每个数组只有一个值;跟hive的udtf很像

class UserDefinedUDTF extends GenericUDTF{

  //这个方法的作用:1.输入参数校验  2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
  override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
    if (args.length != 1) {
      throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
    }
    if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
    }

    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]

    //这里定义的是输出列默认字段名称
    fieldNames.add("col1")
    //这里定义的是输出列字段类型
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  //这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
  override def process(args: Array[AnyRef]): Unit = {
    //将字符串切分成单个字符的数组
    val strLst = args(0).toString.split("")
    for(i <- strLst){
      var tmp:Array[String] = new Array[String](1)
      tmp(0) = i
      //调用forward方法,必须传字符串数组,即使只有一个元素
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}
 

你可能感兴趣的:(大数据资料笔记整理)