SparkSQL中的UDF、UDAF、UDTF实现

分类

根据输入输出之间的关系来分类:

UDF —— 输入一行,输出一行
UDAF —— 输入多行,输出一行
UDTF —— 输入一行,输出多行

UDF函数

1、数据

大狗	三国,水浒,红楼
二狗	金瓶梅
二条	西游,唐诗宋词

2、需求:求出每个人的爱好个数
3、实现

def main(args: Array[String]): Unit = {
    // Input file: first CLI argument if given, otherwise the original demo path.
    val inputPath = args.headOption.getOrElse("D:\\ssc\\likes.txt")

    val spark = SparkSession.builder
      .master("local")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    import spark.implicits._

    // Each line is "<name>\t<comma-separated hobbies>". Splitting with limit 2 keeps the
    // whole remainder as the hobbies field; lines without a tab (e.g. a trailing blank
    // line) are skipped instead of failing with ArrayIndexOutOfBoundsException on x(1).
    val df = spark.sparkContext.textFile(inputPath)
      .map(_.split("\t", 2))
      .filter(_.length == 2)
      .map(x => Likes(x(0), x(1)))
      .toDF()
    df.createOrReplaceTempView("t_team")

    // UDF "teams_length": number of non-blank comma-separated entries.
    // A null column yields 0 instead of an NPE, and "" yields 0 instead of 1
    // (plain `"".split(",").length` is 1).
    val teamsLengthUDF = spark.udf.register("teams_length", (input: String) =>
      Option(input).fold(0)(_.split(",").count(_.trim.nonEmpty))
    )
    println("--------------SQL方式----------------")
    spark.sql("select name,teams,teams_length(teams) as teams_length from t_team").show(false)
    println("--------------API方式----------------")
    df.select($"name", $"teams", teamsLengthUDF($"teams").as("teams_length")).show(false)
    spark.stop()
  }
  case class Likes(name:String,teams:String)
}

UDAF

自定义UDAF需要继承UserDefinedAggregateFunction（注：自Spark 3.0起该类已被标记为弃用，官方推荐改用Aggregator，并通过functions.udaf注册）。

SparkSQL自带的聚合函数（以sum为例）:

 /**
   * Built-in aggregate demo ("many rows in, one row out"): sums `age` per `sex`
   * using Spark SQL's own `sum`; no custom UDAF is involved.
   *
   * @param spark active session; a temp view named `user` is created or replaced on it
   */
  def udafWithSum(spark:SparkSession): Unit = {
    // Non-nullable columns: name, age, sex.
    val userSchema = new StructType()
      .add("name", StringType, false)
      .add("age", IntegerType, false)
      .add("sex", StringType, false)

    val userRows = util.Arrays.asList[Row](
      Row("Luck", 30, "M"),
      Row("Jack", 60, "M"),
      Row("Jim", 19, "F"),
      Row("Lily", 20, "F")
    )

    spark.createDataFrame(userRows, userSchema).createOrReplaceTempView("user")

    spark.sql("select sex,sum(age) from user group by sex").show(false)
  }

自定义UDAF
需求:男性和女性各自的平均年龄

  // avg = sum of the non-null inputs / number of non-null inputs that took part
  object JimAvgUDAF extends UserDefinedAggregateFunction {

    /**
      * Input type: one nullable Double column named "nums".
      * NOTE(review): the SQL caller in this file passes an IntegerType `age` column;
      * presumably Spark casts it to Double before `update` — confirm for your Spark version.
      * @return the input schema
      */
    override def inputSchema: StructType =
      StructType(StructField("nums", DoubleType, true) :: Nil)

    /**
      * Intermediate aggregation buffer.
      * @return schema of (running sum, count of non-null inputs)
      */
    override def bufferSchema: StructType =
      StructType(
        StructField("buffer1", DoubleType, true) ::    // running sum of the inputs
          StructField("buffer2", LongType, true) :: Nil // number of non-null inputs seen
      )

    /**
      * Output data type.
      * @return DoubleType
      */
    override def dataType: DataType = DoubleType

    /**
      * Whether the same input always produces the same output.
      * @return true
      */
    override def deterministic: Boolean = true

    /**
      * Initialise an empty buffer: sum = 0.0, count = 0.
      * @param buffer the buffer to initialise
      */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0.0)
      buffer.update(1, 0L)
    }

    /**
      * Fold one input row into a partial (per-partition) buffer.
      * Null inputs are skipped, matching SQL `avg`; `Row.getDouble` would throw on a null.
      * @param buffer partial aggregation buffer
      * @param input  the input row
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
        buffer.update(1, buffer.getLong(1) + 1)
      }
    }

    /**
      * Combine two partial buffers (global aggregation).
      * @param buffer1 buffer that receives the merged result
      * @param buffer2 buffer to merge in
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    /**
      * Final result: sum / count.
      * @param buffer the fully merged buffer
      * @return the average as a Double, or null when no non-null input was seen
      *         (instead of 0.0 / 0 = NaN)
      */
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      if (count == 0L) null else buffer.getDouble(0) / count
    }
  }

调用

/**
    * Custom-UDAF demo: average `age` per `sex` using [[JimAvgUDAF]], which is registered
    * under the SQL name `avg_udaf` and queried as:
    *   select sex,avg_udaf(age) as avg_age from user group by sex
    * @param spark active session; a temp view named `user` is created or replaced and
    *              `avg_udaf` is registered on it
    */
  def udafAvgWithSex(spark:SparkSession): Unit = {
    val rows = new util.ArrayList[Row]()
    rows.add(Row("Luck",30,"M"))
    rows.add(Row("Jack",60,"M"))
    rows.add(Row("Jim",19,"F"))
    rows.add(Row("Lily",20,"F"))

    // Non-nullable columns: name, age (Int), sex.
    val schema = StructType(
      List(
        StructField("name",StringType,false),
        StructField("age",IntegerType,false),
        StructField("sex",StringType,false)
      )
    )

    val df = spark.createDataFrame(rows,schema)

    df.createOrReplaceTempView("user")

    // Must be registered before the SQL below references `avg_udaf`.
    spark.udf.register("avg_udaf",JimAvgUDAF)

    spark.sql("select sex,avg_udaf(age) as avg_age from user group by sex").show(false)
  }

UDTF

一进多出（以下示例并未实现真正的UDTF，而是用Dataset的flatMap把一行拆成多行）

def udtfExplode(spark:SparkSession): Unit = {
    val rows = new util.ArrayList[Row]()
    rows.add(Row("Luck","Java,JavaScript,Scala"))
    rows.add(Row("Jack","History,English,Math"))

    // Non-nullable columns: teacher, courses (comma-separated list).
    val schema = StructType(
      List(
        StructField("teacher",StringType,false),
        StructField("courses",StringType,false)
      )
    )

    val df = spark.createDataFrame(rows,schema)

    // Brings the case-class encoder for Course into scope for flatMap.
    import spark.implicits._

    // "One row in, many rows out": each (teacher, "c1,c2,...") row becomes one Course
    // per comma-separated entry, in the original order.
    val courseDs = df.flatMap { row =>
      val teacher = row.getString(0)
      row.getString(1).split(",").map(course => Course(teacher, course))
    }

    courseDs.printSchema()

    courseDs.show(false)
  }

  case class Course(teacher:String,course: String)

你可能感兴趣的:(sparksql)