SparkSQL 自定义算子UDF、UDAF、UDTF


背景

我根据算子输入输出之间的关系来理解算子分类:

UDF——输入一行,输出一行
UDAF——输入多行,输出一行
UDTF——输入一行,输出多行

本文主要是整理这三种自定义算子的具体实现方式

使用的数据集——用户行为日志user_log.csv,csv中自带首行列头信息,字段定义如下:

1. user_id | 买家id
2. item_id | 商品id
3. cat_id | 商品类别id
4. merchant_id | 卖家id
5. brand_id | 品牌id
6. month | 交易时间:月
7. day | 交易事件:日
8. action | 行为
9. age_range | 买家年龄分段
10. gender | 性别
11. province| 收获地址省份

新手上路,有任何搞错的地方,或者走了弯路,还请大家不吝指出,帮我进步



SparkSQL算子分类

  • 1. UDF
  • 2. UDAF
  • 3. UDTF
  • ► 小结


1. UDF

通过匿名函数的方式注册自定义算子

object UserAnalysis {
  def main(args:Array[String]): Unit ={

    //测试数据所在的本地路径
    val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"

    //创建sparksession
    val sparkSession = SparkSession
      .builder
      .master("local")
      .appName("UserAnalysis")
      .enableHiveSupport()      //启用hive
      .getOrCreate()

    //sparksession直接读取csv,可设置分隔符delimitor.
    val userDF = sparkSession.read
      .option("header","true")
      .csv(userDataPath)

    //将DataFrame注册成视图,然后即可使用hql访问
    userDF.createOrReplaceTempView("userDF")

    //通过匿名函数的方式注册自定义算子:将0和1分别转换成female和male
    sparkSession.udf.register("getGender",(gender:Integer)=>{
      var result="unknown"
      if (gender==0){
        result="female"
      }else if(gender==1){
        result="male"
      }
      result
    })

    val genderDF = sparkSession.sql("select getGender(gender) as A from userDF")

    //显示DataFrame内容
    genderDF.show(10)
  }
}

通过实名函数的方式注册自定义算子

object UserAnalysis {
  def main(args:Array[String]): Unit ={

    //测试数据所在的本地路径
    val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"

    //创建sparksession
    val sparkSession = SparkSession
      .builder
      .master("local")
      .appName("UserAnalysis")
      .enableHiveSupport()      //启用hive
      .getOrCreate()

    //sparksession直接读取csv,可设置分隔符delimitor.
    val userDF = sparkSession.read
      .option("header","true") 
      .csv(userDataPath)

    //将DataFrame注册成视图,然后即可使用hql访问
    userDF.createOrReplaceTempView("userDF")

    /*
    通过实名函数的方式注册自定义算子
    Scala中方法和函数是两个不同的概念,方法无法作为参数进行传递,
    也无法赋值给变量,但是函数是可以的。在Scala中,利用下划线可以将方法转换成函数:
    */
    sparkSession.udf.register("getGender",getGender _)

    val genderDF = sparkSession.sql("select getGender(gender) as A from userDF")

    //显示DataFrame内容
    genderDF.show(10)
  }

  //将0和1分别转换成female和male
  def getGender(gender:Integer): String ={
    var result="unknown"
    if (gender==0){
      result="female"
    }else if(gender==1){
      result="male"
    }
    result
  }
}

通过以上两种方式实现相同算子,得到相同的结果:
SparkSQL 自定义算子UDF、UDAF、UDTF_第1张图片

2. UDAF

通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来自定义UDAF算子

class UserDefinedMax extends UserDefinedAggregateFunction{

  //定义输入数据的类型,两种写法都可以
  //override def inputSchema: StructType = StructType(Array(StructField("input", IntegerType, true)))
  override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)

  //定义聚合过程中所处理的数据类型
  override def bufferSchema: StructType = StructType(Array(StructField("cache", IntegerType, true)))

  //定义输入数据的类型
  override def dataType: DataType = IntegerType

  //规定一致性
  override def deterministic: Boolean = true

  //在聚合之前,每组数据的初始化操作
  override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}

  //每组数据中,当新的值进来的时候,如何进行聚合值的计算
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(input.getInt(0)> buffer.getInt(0))
      buffer(0)=input.getInt(0)
    }

  //合并各个分组的结果
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if(buffer2.getInt(0)> buffer1.getInt(0)){
      buffer1(0)=buffer2.getInt(0)
    }
  }

  //返回最终结果
  override def evaluate(buffer: Row): Any = {buffer.getInt(0)}
}

测试代码

object UserAnalysis {
  def main(args:Array[String]): Unit ={

    //测试数据所在的本地路径
    val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"

    //创建sparksession
    val sparkSession = SparkSession
      .builder
      .master("local")
      .appName("UserAnalysis")
      .enableHiveSupport()      //启用hive
      .getOrCreate()

    //sparksession直接读取csv,可设置分隔符delimitor.
    var userDF = sparkSession.read
      .option("header","true")
      .csv(userDataPath)

    //转换dataframe字段类型或字段名
    import org.apache.spark.sql.functions._
    userDF = userDF .withColumn("item_id", col("item_id").cast(IntegerType))

    //将DataFrame注册成视图,然后即可使用hql访问
    userDF.createOrReplaceTempView("userDF")

    //注册算子,如果UserDefinedMax是object,不用new
    sparkSession.udf.register("UserDefinedMax", new UserDefinedMax)

    //测试sparksql内嵌max算子结果
    val MaxDF = sparkSession.sql("select max(item_id) from userDF")

    MaxDF.show

    //测试用户自定义max算子结果
    val UserDefinedMaxDF = sparkSession.sql("select UserDefinedMax(item_id) from userDF")

    UserDefinedMaxDF.show
  }
}

可以看到两个max算子的输出相同:
在这里插入图片描述
在这里插入图片描述

3. UDTF

通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来自定义UDTF算子

class UserDefinedUDTF extends GenericUDTF{

  //这个方法的作用:1.输入参数校验  2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
  override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
    if (args.length != 1) {
      throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
    }
    if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
    }

    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]

    //这里定义的是输出列默认字段名称
    fieldNames.add("col1")
    //这里定义的是输出列字段类型
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  //这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
  override def process(args: Array[AnyRef]): Unit = {
    //将字符串切分成单个字符的数组
    val strLst = args(0).toString.split("")
    for(i <- strLst){
      var tmp:Array[String] = new Array[String](1)
      tmp(0) = i
      //调用forward方法,必须传字符串数组,即使只有一个元素
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}

测试代码

object UserAnalysis {
  def main(args:Array[String]): Unit ={

    //测试数据所在的本地路径
    val userDataPath = "file:///home/hadoop/data_format/zxc/small1.csv"

    //创建sparksession
    val sparkSession = SparkSession
      .builder
      .master("local")
      .appName("UserAnalysis")
      .enableHiveSupport()      //启用hive
      .getOrCreate()

    //sparksession直接读取csv,可设置分隔符delimitor.
    var userDF = sparkSession.read
      .option("header","true")
      .csv(userDataPath)

    //将DataFrame注册成视图,然后即可使用hql访问
    userDF.createOrReplaceTempView("userDF")

    //注册utdf算子,这里无法使用sparkSession.udf.register()
    sparkSession.sql("CREATE TEMPORARY FUNCTION UserDefinedUDTF as 'com.zxc.sparkAppTest.udtf.UserDefinedUDTF'")

    //使用UDTF算子处理原表userDF
    val UserDefinedUDTFDF = sparkSession.sql(
      "select " +
          "user_id," +
          "item_id," +
          "cat_id," +
          "merchant_id," +
          "brand_id," +
          "month," +
          "day," +
          "action," +
          "age_range," +
          "gender," +
          "UserDefinedUDTF(province) " +
        "from " +
          "userDF"
    )

    UserDefinedUDTFDF.show
  }
}

对比原表和经UDTF算子处理之后的结果表:
SparkSQL 自定义算子UDF、UDAF、UDTF_第2张图片
SparkSQL 自定义算子UDF、UDAF、UDTF_第3张图片

► 小结

  • 关于UDF
    简单粗暴的理解,它就是输入一行输出一行的自定义算子
    我们可以通过实名函数或匿名函数的方式来实现,并使用sparkSession.udf.register()注册
    需要注意,截至目前(spark2.4)最多只支持22个输入参数的UDF

    另外还有一种实现方案(基于spark1.5,spark2.4待测试):
    继承org.apache.hadoop.hive.ql.exec.UDF


  • 关于UDAF
    简单粗暴的理解,它就是输入多行输出一行的自定义算子,比UDF的功能强大一些
    通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来实现UDAF算子,并使用sparkSession.udf.register()注册

    另外还有一种实现方案(基于spark1.5,spark2.4待测试):
    先继承org.apache.hadoop.hive.ql.exec.UDAF
    内部静态类实现org.apache.hadoop.hive.ql.exec.UDAFEvaluator


  • 关于UDTF
    简单粗暴的理解,它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”
    通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来实现UDTF算子,但是似乎无法使用sparkSession.udf.register()注册。注册方法如下:

sparkSession.sql("CREATE TEMPORARY FUNCTION 自定义算子名称 as '算子实现类全限定名称'")
   实现UDTFf还需要注意(基于spark1.5,可能已过时):
   udtf,process方法中对参数需要使用toString,String强转没用
   sparksql子查询必须要有别名
   算子内部使用竖线切分字符串时,需要转义
   udtf调用forward方法,必须传字符串数组,即使只有一个元素

以上就是全部内容,持续更新,请多提宝贵意见

你可能感兴趣的:(Spark)