Spark 源码:register 注册的 UDF 最多支持 22 个传入参数(UDF1 ~ UDF22),下面以两个参数的 UDF2 为例。
//传入两个参数
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}
注册:spark.udf().register(函数名, 函数体, 函数返回值类型 DataType);——第三个参数不可省略。
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("SqlDataSource")
.master("local")
.getOrCreate();
//保留两位小数,四舍五入
spark.udf().register("twoDecimal", new UDF1() {
@Override
public Double call(Double in) throws Exception {
BigDecimal b = new BigDecimal(in);
double res = b.setScale(2,BigDecimal.ROUND_HALF_DOWN).doubleValue();
return res;
}
}
自定义 UDAF 需继承 UserDefinedAggregateFunction,并重写其中的 8 个方法:
下边的以求平均值为案例:
public class MyUDAF extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public MyUDAF() {
List inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn",DataTypes.DoubleType,true));
inputSchema = DataTypes.createStructType(inputFields);
List bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum",DataTypes.DoubleType,true));
bufferFields.add(DataTypes.createStructField("count",DataTypes.DoubleType,true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
//1、该聚合函数的输入参数的数据类型
public StructType inputSchema() {
return inputSchema;
}
//2、聚合缓冲区中的数据类型.(有序性)
public StructType bufferSchema() {
return bufferSchema;
}
//3、返回值的数据类型
public DataType dataType() {
return DataTypes.DoubleType;
}
//4、这个函数是否总是在相同的输入上返回相同的输出,一般为true
public boolean deterministic() {
return true;
}
//5、初始化给定的聚合缓冲区,在索引值为0的sum=0;索引值为1的count=1;
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0,0D);
buffer.update(1,0D);
}
//6、更新
public void update(MutableAggregationBuffer buffer, Row input) {
//如果input的索引值为0的值不为0
if(!input.isNullAt(0)){
double updateSum = buffer.getDouble(0) + input.getDouble(0);
double updateCount = buffer.getDouble(1) + 1;
buffer.update(0,updateSum);
buffer.update(1,updateCount);
}
}
//7、合并两个聚合缓冲区,并将更新后的缓冲区值存储回“buffer1”
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
double mergeSum = buffer1.getDouble(0) + buffer2.getDouble(0);
double mergeCount = buffer1.getDouble(1) + buffer2.getDouble(1);
buffer1.update(0,mergeSum);
buffer1.update(1,mergeCount);
}
//8、计算出最终结果
public Double evaluate(Row buffer) {
return buffer.getDouble(0)/buffer.getDouble(1);
}
}
调用UDAF,实现
public class RunMyUDAF {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("RunMyUDAF")
.master("local")
.getOrCreate();
// Register the function to access it
spark.udf().register("myAverage", new MyUDAF());
Dataset df = spark.read().json("src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 0|
// | Andy| 4537|
// | Justin| 3500|
// | Berta| 0|
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// | Andy| 4500|
// +-------+------+
//保留两位小数,四舍五入
spark.udf().register("twoDecimal", new UDF1() {
@Override
public Double call(Double in) throws Exception {
BigDecimal b = new BigDecimal(in);
double res = b.setScale(2,BigDecimal.ROUND_HALF_DOWN).doubleValue();
return res;
}
}, DataTypes.DoubleType);
Dataset result = spark
.sql("SELECT name,twoDecimal(myAverage(salary)) as avg_salary FROM employees group by name");
result.show();
// +-------+--------------+
// | name| avg_salary |
// +-------+--------------+
// |Michael| 1500.0|
// | Andy| 4512.33|
// | Berta| 2000.0|
// | Justin| 3500.0|
// +-------+--------------+
spark.stop();
}
}