SparkSql------自定义函数UDF和UDAF

UDF

测试数据 

{"name":"aaa", "age":20}
{"name":"bbb", "age":30, "facevalue":80}
{"name":"ccc", "age":28, "facevalue":80}
{"name":"ddd", "age":28, "facevalue":90}

案例(scala语言)

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object UDFFunction {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val df = spark.read.json("people.json")
    df.createOrReplaceTempView("people")

    //创建自定义函数
    spark.udf.register("addName",(x:String)=>"Name:"+x)
    spark.sql("select addName(name),age from people").show()
  }
//    +-----------------+---+
//    |UDF:addName(name)|age|
//    +-----------------+---+
//    |         Name:aaa| 20|
//    |         Name:bbb| 30|
//    |         Name:ccc| 28|
//    |         Name:ddd| 28|
//    +-----------------+---+
}

UDAF

下面展示一个求平均工资的自定义聚合函数。

通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
//自定义UDAF函数
class MyAverage extends UserDefinedAggregateFunction {
  // 输入数据
  def inputSchema: StructType = StructType(List(StructField("Salary",DoubleType,true)))
  // 每一个分区中的 共享变量 存储记录的值
  def bufferSchema: StructType = {
    //                     工资的总和                      工资的总数
    StructType(StructField("sum", DoubleType):: StructField("count", DoubleType)  :: Nil)
  }
  // 返回值的数据类型表示UDAF函数的输出类型
  def dataType: DataType = DoubleType

  //如果有相同的输入,那么是否UDAF函数有相同的输出,有true 否则false
  //UDAF函数中如果输入的数据掺杂着时间,不同时间得到的结果可能是不一样的所以这个值可以设置为false
  //若不掺杂时间,这个值可以输入为true
  def deterministic: Boolean = true

  // 初始化对Buffer中的属性初始化即初始化分区中每一个共享变量
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 存工资的总额
    buffer(0) = 0.0//取得就是sum
    // 存工资的个数
    buffer(1) = 0.0//取得就是count
  }
  // 相同Execute间的数据合并,合并小聚合中的数据即每一个分区中的每一条数据聚合的时候需要调用的方法
  /*
   第一个参数buffer还是共享变量
   第二个参数是一行数据即读取到的数据是以一行来看待的
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      //获取这一行中的工资,然后将工资添加到该方法中
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      //将工资的个数进行加1操作最终是为了计算所有的工资的个数
     buffer(1) = buffer.getDouble(1) + 1
    }
  }
  // 不同Execute间的数据合并,合并大数据中的数即将每一个区分的输出合并形成最后的数据
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //合并总的工资
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    //合并总的工资个数
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }
  // 计算最终结果
  def evaluate(buffer: Row): Double = buffer.getDouble(0) / buffer.getDouble(1)
}
// 需求:统计员工平均薪资
object MyAverage{
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("MyAverage").master("local[*]").getOrCreate()
    // 注册函数
    spark.udf.register("myAverage",new MyAverage)

    val df = spark.read.json("dir/employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    //虽然没有使用groupby那么会将整个数据作为一个组
    val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
    result.show()
  }
}

通过继承Aggregator来实现强类型自定义聚合函数,同样是求平均工资。

import org.apache.spark.sql.expressions.{Aggregator}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
//自定义UDAF函数
// 既然是强类型,可能有case类
case class Employee(name: String, salary: Double)
case class Average(var sum: Double, var count: Double)
//依次配置输入,共享变量,输出的类型,需要使用到case class
class MyAverage extends Aggregator[Employee, Average, Double] {
  // 初始化方法 初始化每一个分区中的 共享变量即定义一个数据结构,保存工资总数和工资总个数,初始都为0
  def zero: Average = Average(0.0, 0.0)
  //每一个分区中的每一条数据聚合的时候需要调用该方法
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  //将每一个分区的输出 合并 形成最后的数据
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // 给出计算结果
  def finish(reduction: Average): Double = reduction.sum / reduction.count
  // 设定中间值类型的编码器,要转换成case类
  // Encoders.product是进行scala元组和case类转换的编码器
  def bufferEncoder: Encoder[Average] = Encoders.product
  // 设定最终输出值的编码器
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object MyAverage{
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("MyAverage").master("local[*]").getOrCreate()
    import spark.implicits._
    val ds = spark.read.json("dir/employees.json").as[Employee]
    ds.show()
    val averageSalary = new MyAverage().toColumn.name("average_salary")
    val result = ds.select(averageSalary)
    result.show()
  }
}

 

你可能感兴趣的:(spark基础)