SparkSQL: Implementing UDFs and UDAFs in Java

UDF

Source: the register overloads in the Spark source go up to UDF22, so a UDF takes at most 22 input parameters.

// Register overload for a UDF that takes two parameters
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
  val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
  functionRegistry.registerFunction(
    name,
    (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}
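
As a quick illustration of that two-argument overload, a Java registration might look like the sketch below. The function name concatWith, the underscore delimiter, and the commented-out query are hypothetical, and a SparkSession named spark is assumed:

import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical two-argument UDF matching the UDF2 overload above
spark.udf().register("concatWith", new UDF2<String, String, String>() {
    @Override
    public String call(String a, String b) throws Exception {
        return a + "_" + b;  // join the two inputs with an underscore
    }
}, DataTypes.StringType);

// spark.sql("SELECT concatWith(first_name, last_name) FROM people").show();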

Registration: spark.udf().register(name, function, returnType);

public static void main(String[] args) {
    SparkSession spark = SparkSession
            .builder()
            .appName("SqlDataSource")
            .master("local")
            .getOrCreate();

    // Keep two decimal places, rounding half up
    spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
        @Override
        public Double call(Double in) throws Exception {
            BigDecimal b = new BigDecimal(in);
            return b.setScale(2, RoundingMode.HALF_UP).doubleValue();
        }
    }, DataTypes.DoubleType);
}
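
Once registered, the UDF is callable from SQL by name. A minimal sanity check could look like this (the query and literal are illustrative):

spark.sql("SELECT twoDecimal(3.14159) AS rounded").show();  // prints 3.14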

UDAF

Extend UserDefinedAggregateFunction and override its eight methods. The example below computes an average:

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MyUDAF extends UserDefinedAggregateFunction {
    private StructType inputSchema;
    private StructType bufferSchema;

    public MyUDAF() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    // 1. Data types of the input arguments to this aggregate function
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    // 2. Data types of the values in the aggregation buffer (order matters)
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    // 3. Data type of the return value
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // 4. Whether this function always returns the same output for the same input; usually true
    @Override
    public boolean deterministic() {
        return true;
    }

    // 5. Initialize the aggregation buffer: sum = 0 at index 0, count = 0 at index 1
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0,0D);
        buffer.update(1,0D);
    }

    // 6. Update the buffer with a new input row
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // Skip null inputs so they do not distort the average
        if (!input.isNullAt(0)) {
            double updateSum = buffer.getDouble(0) + input.getDouble(0);
            double updateCount = buffer.getDouble(1) + 1;
            buffer.update(0,updateSum);
            buffer.update(1,updateCount);
        }
    }

    // 7. Merge two aggregation buffers and store the result back into buffer1
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        double mergeSum = buffer1.getDouble(0) + buffer2.getDouble(0);
        double mergeCount = buffer1.getDouble(1) + buffer2.getDouble(1);
        buffer1.update(0,mergeSum);
        buffer1.update(1,mergeCount);

    }

    // 8. Compute the final result
    @Override
    public Double evaluate(Row buffer) {
        return buffer.getDouble(0)/buffer.getDouble(1);
    }
}
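
To make the lifecycle concrete, here is a hand-trace (a sketch, not runnable code) for one partition holding the two Michael salaries from the sample data in the next section:

// Hypothetical single-partition trace for the inputs 0.0 and 3000.0:
// initialize(buffer)       -> buffer = [sum = 0.0, count = 0.0]
// update(buffer, 0.0)      -> buffer = [0.0, 1.0]
// update(buffer, 3000.0)   -> buffer = [3000.0, 2.0]
// merge(...)               -> nothing to combine with a single partition
// evaluate(buffer)         -> 3000.0 / 2.0 = 1500.0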

Calling the UDAF. Spark invokes update once per input row within each partition, merge to combine partial buffers across partitions, and evaluate to produce the final result:

import java.math.BigDecimal;
import java.math.RoundingMode;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

public class RunMyUDAF {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("RunMyUDAF")
                .master("local")
                .getOrCreate();

        // Register the function to access it
        spark.udf().register("myAverage", new MyUDAF());

        Dataset<Row> df = spark.read().json("src/main/resources/employees.json");
        df.createOrReplaceTempView("employees");
        df.show();

//        +-------+------+
//        |   name|salary|
//        +-------+------+
//        |Michael|     0|
//        |   Andy|  4537|
//        | Justin|  3500|
//        |  Berta|     0|
//        |Michael|  3000|
//        |   Andy|  4500|
//        | Justin|  3500|
//        |  Berta|  4000|
//        |   Andy|  4500|
//        +-------+------+

        // Keep two decimal places, rounding half up
        spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
            @Override
            public Double call(Double in) throws Exception {
                BigDecimal b = new BigDecimal(in);
                return b.setScale(2, RoundingMode.HALF_UP).doubleValue();
            }
        }, DataTypes.DoubleType);


        Dataset<Row> result = spark
                .sql("SELECT name, twoDecimal(myAverage(salary)) AS avg_salary FROM employees GROUP BY name");
        result.show();

//        +-------+----------+
//        |   name|avg_salary|
//        +-------+----------+
//        |Michael|    1500.0|
//        |   Andy|   4512.33|
//        |  Berta|    2000.0|
//        | Justin|    3500.0|
//        +-------+----------+

        spark.stop();
    }
}
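
The same registered functions can also be invoked through the DataFrame API instead of a SQL string. A sketch, reusing df and the registrations from main above (callUDF and col come from org.apache.spark.sql.functions):

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;

// DataFrame-API equivalent of the SQL query above; assumes "myAverage"
// and "twoDecimal" are already registered on this SparkSession.
Dataset<Row> result2 = df.groupBy(col("name"))
        .agg(callUDF("myAverage", col("salary")).alias("raw_avg"))
        .select(col("name"), callUDF("twoDecimal", col("raw_avg")).alias("avg_salary"));
result2.show();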
