Spark User-Defined Functions

Users can define custom functions through the udf facility on SparkSession.

User-Defined Functions

  1. Load the JSON data

    val df = spark.read.json("files\\test.json")
    
  2. Register a user-defined function

    spark.udf.register("addName", (name: String) => "Name:" + name)
    
  3. Create a temporary view and query it

    df.createOrReplaceTempView("test")
    val testDF = spark.sql("select addName(name), name from test")
    testDF.show()
    /*
    +-----------------+----+
    |UDF:addName(name)|name|
    +-----------------+----+
    |        Name:adam|adam|
    |        Name:brad|brad|
    |        Name:carl|carl|
    +-----------------+----+
    */
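The function passed to register is an ordinary Scala closure, so its logic can be exercised without a SparkSession. The following standalone sketch repeats the same function body as the addName UDF above:

```scala
// The same closure that was registered under the name "addName" above.
val addName: String => String = (name: String) => "Name:" + name

// Because it is a plain Scala function, it can be unit-tested outside Spark.
println(addName("adam")) // prints: Name:adam
```

This is also a reason to keep UDF bodies small and pure: the closure itself is serialized and shipped to executors, and a side-effect-free function is trivial to test in isolation.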
    

User-Defined Aggregate Functions (Weakly Typed)

A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. (Note: this class has been deprecated since Spark 3.0 in favor of registering a strongly typed Aggregator via functions.udaf.)

The following methods must be implemented:

  1. inputSchema: the schema of the function's input
  2. bufferSchema: the schema of the intermediate buffer used during aggregation
  3. dataType: the data type of the function's return value
  4. deterministic: whether the function is deterministic, i.e. always returns the same output for the same input
  5. initialize: initializes the buffer before aggregation starts
  6. update: updates the buffer with a new input row
  7. merge: merges two buffers (e.g. from different partitions)
  8. evaluate: computes the final result from the buffer
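These callbacks can be exercised without Spark. The sketch below is a plain-Scala simulation of the lifecycle (the sample ages are made up for illustration): it keeps a (sum, count) buffer and shows how per-partition update results are combined by merge before evaluate produces the answer.

```scala
// A Spark-free sketch of the UDAF lifecycle for an average.
// The buffer is a (sum, count) pair, mirroring a two-field bufferSchema.
def initialize(): (Long, Long) = (0L, 0L)                                      // initialize
def update(b: (Long, Long), age: Long): (Long, Long) = (b._1 + age, b._2 + 1L) // update
def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =                  // merge
  (b1._1 + b2._1, b1._2 + b2._2)
def evaluate(b: (Long, Long)): Double = b._1.toDouble / b._2                   // evaluate

// Hypothetical sample ages. Spark maintains one buffer per partition,
// folds the input rows into it, then merges the partition buffers:
val partition1 = Seq(16L, 18L).foldLeft(initialize())(update)
val partition2 = Seq(18L).foldLeft(initialize())(update)
val result = evaluate(merge(partition1, partition2))
println(result) // prints: 17.333333333333332
```

Note that evaluate converts to Double before dividing; dividing two Longs first would silently truncate the result.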

The following example computes the average age:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class AvgAge extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("age", LongType)

  override def bufferSchema: StructType = new StructType().add("sum", LongType).add("count", LongType)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Convert to Double before dividing; Long / Long would truncate the result
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}

Use the registered aggregate function in a query:

object UdafDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("UdafDemo").getOrCreate()
    
    val df = spark.read.json("files\\test.json")
    df.createOrReplaceTempView("test")
    
    val avgAge = new AvgAge()
    spark.udf.register("avgAge", avgAge)
    
    val avgAgeDf = spark.sql("select avgAge(age) from test")
    
    avgAgeDf.show()
  }
}
/*
+------------------+
|       avgage(age)|
+------------------+
|17.333333333333332|
+------------------+
*/

User-Defined Aggregate Functions (Strongly Typed)

A strongly typed aggregate function is defined by extending Aggregator[IN, BUF, OUT].

The following methods must be implemented:

  1. zero: the initial (zero) value of the buffer
  2. reduce: folds an input record into the buffer
  3. merge: merges two buffers
  4. finish: computes the final result from the buffer
  5. bufferEncoder: the encoder for the buffer type
  6. outputEncoder: the encoder for the output type

Note: the Encoders class provides implementations for different types. For custom objects (case classes) use Encoders.product; for primitive types use the matching encoder, e.g. Encoders.scalaDouble for Double.

The following example computes the average age:

  1. Define a Person case class

    case class Person(id: Long, name: String, age: Long)
    
  2. Define the buffer case class

    case class AvgBuffer(sum: Long, count: Int)
    
  3. Implement the Aggregator

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator
    
    class MyAvgAge extends Aggregator[Person, AvgBuffer, Double] {
    
      override def zero: AvgBuffer = AvgBuffer(0, 0)
    
      // AvgBuffer is an immutable case class, so each step returns a new buffer
      override def reduce(b: AvgBuffer, a: Person): AvgBuffer =
        AvgBuffer(b.sum + a.age, b.count + 1)
    
      override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =
        AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
    
      override def finish(reduction: AvgBuffer): Double = reduction.sum.toDouble / reduction.count
    
      override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
    
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    
    }
    
  4. Use the aggregator to compute the average age

    object Udaf1Demo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("Udaf1Demo").getOrCreate()
        import spark.implicits._
    
        val df = spark.read.json("files\\test.json")
    
        val avgAge = new MyAvgAge
        val avgCol = avgAge.toColumn.name("avgAge")
    
        val ds = df.as[Person]
        ds.select(avgCol).show()
    
      }
    }
    /*
    +------------------+
    |            avgAge|
    +------------------+
    |17.333333333333332|
    +------------------+
    */
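Because the buffer is an immutable case class, reduce and merge return new buffer instances instead of mutating fields. The contract can be exercised without Spark; the sketch below repeats the case classes so it is self-contained and uses made-up sample records:

```scala
case class Person(id: Long, name: String, age: Long)
case class AvgBuffer(sum: Long, count: Int)

// Pure-function view of the Aggregator contract: each step returns a new buffer.
def zero: AvgBuffer = AvgBuffer(0, 0)
def reduce(b: AvgBuffer, a: Person): AvgBuffer = AvgBuffer(b.sum + a.age, b.count + 1)
def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
def finish(r: AvgBuffer): Double = r.sum.toDouble / r.count

// Hypothetical sample records. Spark would reduce within each partition,
// merge the partition buffers, then call finish on the combined buffer.
val people = Seq(Person(1, "adam", 16), Person(2, "brad", 18), Person(3, "carl", 18))
val result = finish(people.foldLeft(zero)(reduce))
println(result) // prints: 17.333333333333332
```

Returning fresh instances keeps the aggregation logic referentially transparent, which is what makes the strongly typed API easy to unit-test compared with the mutable buffer of the weakly typed API.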
    
