SparkSQL-自定义函数

用户自定义UDF函数

#创建DataFrame
scala> val df = spark.read.json("..../user.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

scala> df.show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+

#addName 就是自定义函数的名字,x=>String 就是查询的参数,比如查询语句中的name
scala> spark.udf.register("addName", (x:String)=> "Name:"+x)
res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(,StringType,Some(List(StringType)))

#创建临时表
scala> df.createOrReplaceTempView("people")

#使用自定义函数进行查询
scala> spark.sql("select addName(name), age from people").show()
+-----------------+----+
|UDF:addName(name)| age|
+-----------------+----+
|     Name:Michael|null|
|        Name:Andy|  30|
|      Name:Justin|  19|
+-----------------+----+

用户自定义聚合函数(UFAF)

强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。

弱类型聚合函数

弱类型用户自定义聚合函数:通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。下面展示一个求平均工资的自定义聚合函数。

user.json:

{
     "name":"1","age":20}
{
     "name":"2","age":30}
{
     "name":"3","age":40}

代码:

import org.apache.spark
import org.apache.spark.sql.expressions.{
     MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{
     DataType, DoubleType, LongType, StructType}
import org.apache.spark.{
     SparkConf, sql}
import org.apache.spark.sql.{
     DataFrame, Row, SparkSession}

object SparkSqlTest {
     
  def main(args: Array[String]): Unit = {
     
    val sparkConf = new SparkConf().setMaster("local").setAppName("SparkSql")
    val spark: SparkSession = new sql.SparkSession.Builder()
//      .appName("SparkSql")
//      .master("local[*]")
      .config(sparkConf)
      .getOrCreate()
    //自定义聚合函数
    //创建自定义聚合函数
    val udaf = new MyAgeAvgFunction

    //注册聚合函数
    spark.udf.register("avgAge", udaf)

    //使用聚合函数
    val df: DataFrame = spark.read.json("in/user.json")
    df.createOrReplaceTempView("user")
    spark.sql("select avgAge(age) from user").show()

    //释放资源
    spark.stop()
  }
}
/*
 用户自定义聚合函数
 1.继承UserDefinedAggregateFunction
 2.实现方法
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction{
     

  //函数输入数据结构
  override def inputSchema: StructType = {
     
    //当前输入的数据 age,类型为LongType
    new StructType().add("age", LongType)
  }

  //计算时的数据结构
  override def bufferSchema: StructType = {
     
    //当前输入的数据 age,类型为LongType
    new StructType().add("sum", LongType).add("count", LongType)
  }

  //函数返回的数据类型
  override def dataType: DataType = DoubleType

  //函数稳定性
  override def deterministic: Boolean = true

  //计算之前缓冲区的初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
     
    buffer(0) = 0L//将sum初始化为0,L是指float类型
    buffer(1) = 0L//将count初始化为0,L是指float类型
  }

  //根据查询结果更新缓冲区数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
     
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1

  }

  //将多个结点的缓冲区合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
     
    //sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    //count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //计算
  override def evaluate(buffer: Row): Any = {
     
    buffer.getLong(0).toDouble / buffer.getLong(1).toDouble
  }
}

强类型聚合函数

强类型用户自定义聚合函数:通过继承Aggregator来实现强类型自定义聚合函数
import org.apache.spark.sql.expressions.Aggregator

import org.apache.spark.{
     Aggregator, SparkConf, sql}
import org.apache.spark.sql.{
     DataFrame, Row, SparkSession}
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator

object SparkSqlTest {
     
  def main(args: Array[String]): Unit = {
     
    val sparkConf = new SparkConf().setMaster("local").setAppName("SparkSql")
    val spark: SparkSession = new sql.SparkSession.Builder()
//      .appName("SparkSql")
//      .master("local[*]")
      .config(sparkConf)
      .getOrCreate()

    import spark.implicits._

    //创建自定义聚合函数,注意聚合函数类的参数
    val udaf = new MyAgeAvgClassFunction

    //将聚合函数转换为查询列,强类型的自定义函数不能使用sql方法查询,需要使用DSL风格
    val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("AvgAge")

    val df: DataFrame = spark.read.json("in/user.json")
    val userDS: Dataset[UserBean] = df.as[UserBean]//每一条数据都是UserBean

    //使用聚合函数
    userDS.select(avgCol).show

    //释放资源
    spark.stop()
  }
}

//age类型为BigInt, 因为从文件读取的数字默认为BigInt
case class UserBean(name:String, age: BigInt)
//样例类中的属性默认是val的,在后面需要改变,所以将其设置为var类型
case class AvgBuffer(var sum: BigInt,var count: Double)
/*
 用户自定义聚合函数(强类型)
 1.继承UserDefinedAggregateFunction
 2.实现方法
 3.
 4.
 */
class MyAgeAvgClassFunction extends Aggregator[ UserBean, AvgBuffer, Double ] {
     
  //初始化
  override def zero: AvgBuffer = {
     
    AvgBuffer(0,0)//不是ArrayBuffer(0,0)
  }

  //聚合数据
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
     
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  //缓冲区的合并操作
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
     
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  //完成计算
  override def finish(reduction: AvgBuffer): Double = {
     
    //将缓冲区中的数据进行计算
    reduction.sum.toDouble / reduction.count
  }

  //下面这两个几乎固定
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

你可能感兴趣的:(spark)