Spark SQL: UDF and UDAF

1. UDF usage

// Register the function
spark.udf.register("prefix1", (name: String) => {
  "Name:" + name
})
// Use the function
spark.sql("select *, prefix1(name) from users").show()

2. UDAF usage

2.1 Untyped (weakly typed)

// 1. Define the UDAF (untyped API; works on versions before 3.0.0, deprecated since 3.0.0)
package com.shufang.rdd_ds_df

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

class MyUDAF extends UserDefinedAggregateFunction {
  // IN: schema of the input arguments
  override def inputSchema: StructType = {
    StructType(
      Array(
        StructField("age", LongType)
      )
    )
  }

  // BUFFER: schema of the aggregation buffer
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total", LongType),
        StructField("count", LongType)
      )
    )
  }

  // OUT: the result data type
  override def dataType: DataType = LongType

  // Deterministic: the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the aggregation buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // equivalent to: buffer(0) = 0L; buffer(1) = 0L
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1)
  }

  // Merge two buffers (partial aggregates from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Compute the average; note this is integer division because dataType is LongType
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1)
  }
}


// 2. Register & use
spark.udf.register("ageAvg", new MyUDAF)
spark.sql("select ageAvg(id) as av from users").show()

2.2 Strongly typed (recommended since Spark 3.0.0)

// 1. Declare and implement
package com.shufang.rdd_ds_df

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Since 3.0.0, UserDefinedAggregateFunction is deprecated:
 * "Aggregator[IN, BUF, OUT] should now be registered as a UDF
 *  via the functions.udaf(agg) method."
 */
case class Buff(var total: Long, var count: Long)

class MyUDAF1 extends Aggregator[Long, Buff, Long] {
  // Initial (zero) value of the buffer
  override def zero: Buff = Buff(0L, 0L)

  // Fold one input value into the buffer
  override def reduce(b: Buff, a: Long): Buff = {
    b.count += 1
    b.total += a
    b
  }

  // Merge two buffers (partial aggregates)
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.count = b1.count + b2.count
    b1.total = b1.total + b2.total
    b1
  }

  // Compute the final result (integer division, as in the untyped version)
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Encoder for the buffer type (a case class, hence Encoders.product)
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output type
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}


// 2. Register & use; the registration differs: it goes through functions.udaf
import org.apache.spark.sql.functions

spark.udf.register("ageAvg", functions.udaf(new MyUDAF1()))
spark.sql("select ageAvg(id) as av from users").show()
 

2.3 Strongly typed UDAF on earlier versions

On versions before 3.0.0 a strongly typed UDAF cannot be registered for SQL; it has to be used through the Dataset DSL, Spark SQL's typed domain-specific language.

// 1. Declare it; here each row of the Dataset (a User) is the input value
package com.shufang.rdd_ds_df

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Same Aggregator pattern as above, but typed on the whole row (User)
 * instead of a single Long column.
 */
// case class Buff(var total: Long, var count: Long)  // already defined above
class MyUDAF2 extends Aggregator[User, Buff, Long] {
  // Initial (zero) value of the buffer
  override def zero: Buff = Buff(0L, 0L)

  // Fold one input row (a User) into the buffer
  override def reduce(b: Buff, a: User): Buff = {
    b.count += 1
    b.total += a.id
    b
  }

  // Merge two buffers (partial aggregates)
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.count = b1.count + b2.count
    b1.total = b1.total + b2.total
    b1
  }

  // Compute the final result
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Encoder for the buffer type
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output type
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
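
The User case class itself is not shown in the post; a hypothetical definition that matches the users data used throughout would be:

// Hypothetical row type for the users view; field names are assumptions
case class User(id: Long, name: String)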

// 2. Use it via the typed Dataset DSL
import org.apache.spark.sql.{Dataset, TypedColumn}

val column: TypedColumn[User, Long] = new MyUDAF2().toColumn
val ds: Dataset[User] = df.as[User]
ds.select(column).show()
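
toColumn produces an auto-generated column name in the output. A small usage note (not in the original): TypedColumn overrides name so the encoder is preserved, which gives the result a readable header:

// Give the aggregated column a readable name in the output
ds.select(column.name("ageAvg")).show()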
