Spark Study Notes 11: SparkSQL User-Defined Functions (UDF, UDAF, UDTF)

Table of Contents

  • UDF: User-Defined Function (One-to-One)
    • Overview
    • Usage
    • Implementation approaches
    • Complete example
  • UDAF: User-Defined Aggregate Function (Many-to-One)
    • Overview
    • Usage
    • Implementation approaches
  • UDTF: User-Defined Table Function (One-to-Many)
    • Overview
    • Implementation

UDF: User-Defined Function (One-to-One)

Overview

A UDF takes one record as input and produces one record as output, a one-to-one mapping much like the map operator.

Usage

A UDF can be used in two ways: directly in SQL, or through the DataFrame DSL.

  • In SQL
val myconcat3 = sparkSession.udf.register("myconcat3",new MyConcat)

//the DataFrame must be registered as a temp view first, e.g. df.createOrReplaceTempView("sensor")
sparkSession.sql("select myconcat3(id,timestamp) as newid,temperature from sensor").show()
  • In the DSL
val myconcat3 = sparkSession.udf.register("myconcat3",new MyConcat)

df.select(myconcat3($"id",$"timestamp") as "newid",$"temperature").show()

Implementation approaches

There are three ways to implement a UDF:

  • With an anonymous function
val myconcat = sparkSession.udf.register("myconcat",(data1:String,data2:Long)=>{
  data1.concat(data2.toString)
})
  • With the udf function
//import the udf helper
import org.apache.spark.sql.functions.{col, udf}
val myconcat2 = udf[String,String,Long]((data1:String,data2:Long)=>{
  data1.concat(data2.toString)
})

Note: the udf function can take multiple input parameters. The example above has two input parameters, so the prototype is udf[R,T1,T2]:

R is the return type, T1 is the type of the first parameter, and T2 is the type of the second parameter. (A three-parameter sketch follows this list.)
  • By extending the Function trait
class MyConcat extends Function2[String,Long,String] with Serializable{
  override def apply(v1: String, v2: Long): String = {
    v1.concat(v2.toString)
  }
}

Notes on Function2[T1,T2,R]:

  • Because our UDF takes two input parameters, we extend Function2.
  • T1 is the type of the first parameter, T2 is the type of the second.
  • R is the return type.
  • When extending a Function trait, the class must also mix in Serializable; otherwise Spark reports an error (typically a "Task not serializable" exception).
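For illustration, here is a minimal sketch of how both forms scale to three input parameters; the names concat3cols and MyConcat3 are made up for this example, and the columns are the ones from the sensor data used below:

import org.apache.spark.sql.functions.udf

//udf[R,T1,T2,T3]: three input parameters, String result
val concat3cols = udf[String,String,Long,Double]((id:String,ts:Long,temp:Double)=>{
  s"${id}_${ts}_${temp}"
})

//the Function-trait version simply moves up to Function3
class MyConcat3 extends Function3[String,Long,Double,String] with Serializable{
  override def apply(v1: String, v2: Long, v3: Double): String = s"${v1}_${v2}_${v3}"
}

df.select(concat3cols($"id",$"timestamp",$"temperature") as "combined").show()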

Complete example

package com.hjt.yxh.hw.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object UDFApp {
  def main(args: Array[String]): Unit = {
    val conf:SparkConf = new SparkConf()
    conf.setMaster("local[*]").setAppName("UDFApp")

    val sparkSession:SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString",true)
      .getOrCreate()

    import sparkSession.implicits._

    val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    val df = sparkSession.read
      .format("csv")
      .schema("id VARCHAR(32),timestamp BIGINT,temperature DECIMAL(5,2)")
      .load(inpath)

    df.show(false)

    //Approach 1: anonymous function
    val myconcat = sparkSession.udf.register("myconcat",(data1:String,data2:Long)=>{
      data1.concat(data2.toString)
    })
    df.select(myconcat($"id",$"timestamp") as "newid",$"temperature").show()

    //Approach 2: the udf function
    val myconcat2 = udf[String,String,Long]((data1:String,data2:Long)=>{
      data1.concat(data2.toString)
    })

    sparkSession.udf.register("myconcat2",myconcat2)
    df.select(myconcat2($"id",$"timestamp") as "newid",$"temperature").show()

    //Approach 3: a class extending Function2
    val myconcat3 = sparkSession.udf.register("myconcat3",new MyConcat)

    df.select(myconcat3($"id",$"timestamp") as "newid",$"temperature").show()

    sparkSession.close()
  }

  class MyConcat extends Function2[String,Long,String] with Serializable{
    override def apply(v1: String, v2: Long): String = {
      v1.concat(v2.toString)
    }
  }
}

Tip: using the VARCHAR type in the schema raises an error by default, because Spark does not enable VARCHAR/CHAR type support out of the box; set the parameter spark.sql.legacy.charVarcharAsString to true:

val sparkSession:SparkSession = SparkSession.builder()
  .config(conf)
  .config("spark.sql.legacy.charVarcharAsString",true)
  .getOrCreate()

UDAF: User-Defined Aggregate Function (Many-to-One)

Overview

A UDAF is a user-defined aggregate function: it consumes multiple input rows and produces a single aggregated value as its result.
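For intuition, the built-in avg aggregate behaves exactly this way; a minimal sketch against the sensor Dataset (ds) from the complete example below:

import org.apache.spark.sql.functions.avg

//many input rows per id are reduced to one output row per id
ds.groupBy("id").agg(avg("temperature") as "avg_temperature").show()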

Usage

Like a UDF, a UDAF can be used in two ways: in SQL, or through the DSL.

  • In SQL
val myavg = udaf(new MyAggregator)
sparkSession.udf.register("myavg",myavg)
sparkSession.sql("select id,myavg(temperature) from sensor group by id ").show()
  • In the DSL
val myavg = udaf(new MyAggregator)
sparkSession.udf.register("myavg",myavg)

ds.groupBy("id").agg(myavg($"temperature")).show()


Implementation approaches

  • Approach 1: implement a custom Aggregator (the officially recommended approach since Spark 3.0)

This approach extends the Aggregator[IN,BUF,OUT] class, where:

* IN is the type of the values fed into the aggregation
* BUF is the type of the aggregation buffer
* OUT is the type of the output result

Used as an untyped (weakly typed) UDAF:

case class AggBuffer(var count:Long,var sum:Double)

class MyAggregator extends Aggregator[Double,AggBuffer,Double]{

  //initialize the buffer
  override def zero: AggBuffer = {
    AggBuffer(0,0)
  }

  //reduce: fold one input value into the buffer
  override def reduce(b: AggBuffer, a: Double): AggBuffer = {
    b.count+=1
    b.sum+=a
    b
  }

  //merge two buffers
  override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
    b1.sum+= b2.sum
    b1.count += b2.count
    b1
  }

  //produce the final result from the buffer
  override def finish(reduction: AggBuffer): Double = {
    reduction.sum/reduction.count
  }

  //encoder for the buffer (for a custom case class this is always Encoders.product)
  override def bufferEncoder: Encoder[AggBuffer] = Encoders.product

  //encoder for the output (Encoders.scalaDouble for a Double result)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
Complete example:

package com.hjt.yxh.hw.sparksql

import com.hjt.yxh.hw.transmate.SensorReading
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{avg, udaf}


case class AggBuffer(var count:Long,var sum:Double)

class MyAggregator extends Aggregator[Double,AggBuffer,Double]{

  //initialize the buffer
  override def zero: AggBuffer = {
    AggBuffer(0,0)
  }

  //reduce: fold one input value into the buffer
  override def reduce(b: AggBuffer, a: Double): AggBuffer = {
    b.count+=1
    b.sum+=a
    b
  }

  //merge two buffers
  override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
    b1.sum+= b2.sum
    b1.count += b2.count
    b1
  }

  //produce the final result from the buffer
  override def finish(reduction: AggBuffer): Double = {
    reduction.sum/reduction.count
  }

  //encoder for the buffer (for a custom case class this is always Encoders.product)
  override def bufferEncoder: Encoder[AggBuffer] = Encoders.product

  //encoder for the output (Encoders.scalaDouble for a Double result)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object UDAFApp {

  def main(args: Array[String]): Unit = {
    val conf:SparkConf = new SparkConf()
    conf.setMaster("local[*]")
      .setAppName("UDAFApp")

    val sparkSession:SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString",true)
      .getOrCreate()
    import sparkSession.implicits._

    val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"

    val df = sparkSession.read.format("CSV")
      .schema("id VARCHAR(32), timestamp BIGINT,temperature Double")
      .load(inpath)

    val ds:Dataset[SensorReading] = df.as[SensorReading]
    ds.createOrReplaceTempView("sensor")

    ds.show()
    val myavg = udaf(new MyAggregator)
    sparkSession.udf.register("myavg",myavg)

    ds.groupBy("id").agg(myavg($"temperature")).show()

    sparkSession.stop()

  }
}

  • Strongly typed UDAF example (typed Dataset API)
package com.hjt.yxh.hw.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator


case class AggBuffer(var count:Long,var sum:Double)
case class SensorReading(id:String,timestamp:BigInt,temperature:Double)

class MyAggregator extends Aggregator[SensorReading,AggBuffer,Double]{

  //initialize the buffer
  override def zero: AggBuffer = {
    AggBuffer(0,0)
  }

  //reduce: fold one input record into the buffer
  override def reduce(b: AggBuffer, a: SensorReading): AggBuffer = {
    b.count+=1
    b.sum+=a.temperature
    b
  }

  //merge two buffers
  override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
    b1.sum+= b2.sum
    b1.count += b2.count
    b1
  }

  //produce the final result from the buffer
  override def finish(reduction: AggBuffer): Double = {
    reduction.sum/reduction.count
  }

  //encoder for the buffer (for a custom case class this is always Encoders.product)
  override def bufferEncoder: Encoder[AggBuffer] = Encoders.product

  //encoder for the output (Encoders.scalaDouble for a Double result)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object UDAFApp {

  def main(args: Array[String]): Unit = {
    val conf:SparkConf = new SparkConf()
    conf.setMaster("local[*]")
      .setAppName("UDAFApp")

    val sparkSession:SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString",true)
      .getOrCreate()
    import sparkSession.implicits._

    val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"

    val df = sparkSession.read.format("CSV")
      .schema("id VARCHAR(32), timestamp BIGINT,temperature Double")
      .load(inpath)

    val ds:Dataset[SensorReading] = df.as[SensorReading]
    ds.createOrReplaceTempView("sensor")

    //toColumn turns the Aggregator into a typed Column for use on a Dataset
    val myavg = new MyAggregator().toColumn.name("avg_temperature")
    //functions.udaf wraps it for SQL/untyped use
    sparkSession.udf.register("myavg",functions.udaf(new MyAggregator()))

    val ret = ds.select(myavg)
    ret.show()

    sparkSession.stop()

  }
}
  • Approach 2: extend UserDefinedAggregateFunction (deprecated in newer Spark versions in favor of Aggregator)
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, DoubleType, LongType, StructType}

class MyAvg extends UserDefinedAggregateFunction{

  //schema of the input values
  override def inputSchema: StructType = new StructType()
    .add("temperature","double")

  //schema of the aggregation buffer
  override def bufferSchema: StructType = new StructType()
    .add("count",LongType)
    .add("sum",DoubleType)

  //data type of the output
  override def dataType: DataType = DataTypes.DoubleType

  //the same input always produces the same output
  override def deterministic: Boolean = true

  //initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0,0L)
    buffer.update(1,0.0)
  }

  //update the buffer with one input row (the reduce step)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    var count = buffer.getLong(0)
    var sum = buffer.getDouble(1)
    count+= 1
    sum+= input.getDouble(0)
    buffer.update(0,count)
    buffer.update(1,sum)
  }

  //merge two buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0,buffer1.getLong(0)+buffer2.getLong(0))
    buffer1.update(1,buffer1.getDouble(1)+buffer2.getDouble(1))
  }

  //compute and return the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(1)/buffer.getLong(0)
  }
}
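A minimal usage sketch for this legacy API, assuming the sensor temp view registered in the UDAF example above (the name myavg2 is just for illustration); sparkSession.udf.register also accepts a UserDefinedAggregateFunction instance directly:

sparkSession.udf.register("myavg2",new MyAvg)
sparkSession.sql("select id,myavg2(temperature) as avg_temperature from sensor group by id").show()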

UDTF: User-Defined Table Function (One-to-Many)

Overview

Spark SQL does not ship a UDTF API of its own; to use one you must enable Hive support and implement a Hive GenericUDTF.
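As a point of comparison, a simple one-to-many split does not always need a UDTF: the built-in split and explode functions cover many such cases without Hive support. A minimal sketch, assuming ids of the form sensor_1 and the df1 DataFrame from the example below:

import org.apache.spark.sql.functions.{explode, split}

//each input row becomes one output row per "_"-separated token
df1.select(explode(split($"id","_")) as "type").show()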

Implementation

import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

class MySplit extends GenericUDTF {

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    //exactly one argument is expected
    if (argOIs.length != 1) {
      throw new UDFArgumentException("exactly one argument is required")
    }
    //the argument must be a primitive type
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("argument type mismatch")
    }
    //output schema: a single string column named "type"
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  override def process(objects: Array[AnyRef]): Unit = {
    val data_list:Array[String] = objects(0).toString.split("_")
    //forward one output row per split element
    for(item <- data_list){
      val temp = new Array[String](1)
      temp(0)=item
      forward(temp)
    }
  }

  override def close(): Unit = {}
}


object UDTFApp {

  def main(args: Array[String]): Unit = {
    val conf:SparkConf = new SparkConf()
    conf.setMaster("local[*]").setAppName("UDTFTest")

    val sparkSession:SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()
    import sparkSession.implicits._

    val inpath:String = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    val df1 = sparkSession.read.format("csv")
      .schema("id String,timestamp Bigint,temperature double")
      .load(inpath)
    
    //register the DataFrame as a temp view
    df1.createOrReplaceTempView("sensor")
    
    //create a temporary function using the fully qualified class name
    sparkSession.sql("create TEMPORARY function mySplit as 'com.hjt.yxh.hw.sql.MySplit'")

    sparkSession.sql("select mySplit(id) from sensor").show()

    sparkSession.stop()
  }

}
