Detailed examples of custom UDF, UDAF, and UDTF functions in Spark SQL

1. UDF function

// Register the function
spark.udf.register("prefix1", (name: String) => {
    "Name:" + name
})
// Use the function in SQL
spark.sql("select *, prefix1(name) from users").show()

2. UDAF function

(1) Spark 3.0.0 and later

package com.yyds.Test

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
  * Custom UDAF (for Spark 3.0.0 and later)
  */
class _01_MyAvgNew extends Aggregator[Double, (Double, Int), Double]{
  // Initial (zero) value of the aggregation buffer: (sum, count)
  override def zero: (Double, Int) = (0.0, 0)
  // Per-partition (local) aggregation: fold one input value into the buffer
  override def reduce(b: (Double, Int), a: Double): (Double, Int) = {
    (b._1 + a, b._2 + 1)
  }
  // Merge the buffers of two partitions (global aggregation)
  override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = {
    (b1._1 + b2._1, b1._2 + b2._2)
  }
  // Compute the final result from the merged buffer
  override def finish(reduction: (Double, Int)): Double = {
    reduction._1 / reduction._2
  }
  // Encoder for the intermediate buffer
  override def bufferEncoder: Encoder[(Double, Int)] = {
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
  }

  // Encoder for the output value
  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}
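Besides being registered for SQL (as in the driver below), an Aggregator can also be used directly in the typed Dataset API through toColumn. A minimal sketch, assuming a Dataset[Double] of ages built from hypothetical sample values:

import spark.implicits._

val ages = Seq(20.0, 24.0, 18.0).toDS()   // hypothetical sample data
// toColumn turns the Aggregator into a TypedColumn that can be selected directly
ages.select(new _01_MyAvgNew().toColumn.name("avg_age")).show()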
package com.yyds.Test

import com.yyds.common.SparkUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{DataFrame, SparkSession}


/**
  * Spark SQL custom UDAF driver.
  * Sample input data (udaf.json), one JSON object per line:
  *   {"id": 1001, "name": "foo", "sex": "man", "age": 20}
  *   {"id": 1002, "name": "bar", "sex": "man", "age": 24}
  *   {"id": 1003, "name": "baz", "sex": "man", "age": 18}
  *   {"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
  *   {"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
  *   {"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
  */
object _01_SparkSql_UDAF {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    // Create the Spark session (execution environment)
    val spark: SparkSession = SparkUtils.sparkSessionWithNoHive(this.getClass.getSimpleName)

    // Register the custom UDAF: the udaf() wrapper adapts the Aggregator for use in SQL
    import org.apache.spark.sql.functions._
    spark.udf.register("myAvg", udaf(new _01_MyAvgNew))
    // Read the JSON data
    val df: DataFrame = spark.read.json("file:\\D:\\works\\spark-sql\\data\\udaf.json")

    df.createTempView("students")
    // Use the custom UDAF in SQL
    spark.sql(
      """
        |
        |select
        |  sex,
        |  myAvg(age) as avg
        |from
        |  students
        |group by sex
        |
        |
      """.stripMargin).show(2)
  }
}

Output:
+-----+------------------+
|  sex|               avg|
+-----+------------------+
|  man|20.666666666666668|
|woman|18.666666666666668|
+-----+------------------+
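The SparkUtils helper used above (and again in the later examples) is not shown in the post. A minimal sketch of what sparkSessionWithNoHive and sparkSessionWithHive might look like, purely as an assumption:

package com.yyds.common

import org.apache.spark.sql.SparkSession

// Assumed implementation of the helper referenced by the drivers in this post
object SparkUtils {

  def sparkSessionWithNoHive(appName: String): SparkSession =
    SparkSession.builder()
      .appName(appName)
      .master("local[*]")
      .getOrCreate()

  def sparkSessionWithHive(appName: String): SparkSession =
    SparkSession.builder()
      .appName(appName)
      .master("local[*]")
      .enableHiveSupport()   // needed for the Hive-based UDTF example below
      .getOrCreate()
}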

(2) Before Spark 3.0.0 (UserDefinedAggregateFunction, deprecated since 3.0)

package com.yyds.Test

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}


/**
  * Custom UDAF (for versions before Spark 3.0.0)
  */
class _01_MyAvgOld extends UserDefinedAggregateFunction{

  // Input schema of the aggregate function
  override def inputSchema: StructType = {
    new StructType().add("age",LongType)
  }

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum",LongType).add("count",LongType)
  }

  // Data type of the aggregate function's result
  override def dataType: DataType = DoubleType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
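A UserDefinedAggregateFunction can also be applied without SQL registration by calling it on columns in the DataFrame API. A minimal sketch, assuming the same df as in the driver below:

import org.apache.spark.sql.functions.col

// Apply the legacy UDAF directly in a DataFrame aggregation (sketch)
val myAvgOld = new _01_MyAvgOld
df.groupBy("sex").agg(myAvgOld(col("age")).as("avg")).show()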

package com.yyds.Test

import com.yyds.common.SparkUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
  * Spark SQL custom UDAF driver (legacy UserDefinedAggregateFunction API)
  *
  *
  */
object _01_SparkSql_UDAF {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    // Create the Spark session (execution environment)
    val spark: SparkSession = SparkUtils.sparkSessionWithNoHive(this.getClass.getSimpleName)

    // Register the custom UDAF: with the legacy API the instance is passed directly
//    import org.apache.spark.sql.functions._
//    spark.udf.register("myAvg", udaf(new _01_MyAvgNew))
    spark.udf.register("myAvg", new _01_MyAvgOld)
    // Read the JSON data
    val df: DataFrame = spark.read.json("file:\\D:\\works\\spark-sql\\data\\udaf.json")

    df.createTempView("students")
    // Use the custom UDAF in SQL
    spark.sql(
      """
        |
        |select
        |  sex,
        |  myAvg(age) as avg
        |from
        |  students
        |group by sex
        |
        |
      """.stripMargin).show(2)
  }
}

Output:
+-----+------------------+
|  sex|               avg|
+-----+------------------+
|  man|20.666666666666668|
|woman|18.666666666666668|
+-----+------------------+

3. UDTF function

package com.yyds.Test

import java.util
import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}

/**
  * Custom UDTF.
  *
  *   A UDTF can be implemented by extending the abstract class
  *   org.apache.hadoop.hive.ql.udf.generic.GenericUDTF, but it apparently cannot be
  *   registered with sparkSession.udf.register().
  *
  *   Instead, register it with
  *   sparkSession.sql("CREATE TEMPORARY FUNCTION <function name> AS '<fully qualified class name>'")
  */
class _02_MyUdtf extends GenericUDTF{

  // initialize does two things: 1) validates the input arguments, 2) defines the output columns;
  // there can be more than one column, so a UDTF can emit multiple rows and columns
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("Exactly one argument is required")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("Argument type mismatch: a primitive type is expected")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]

    // Output column name
    fieldNames.add("type")

    // Output column type
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldOIs)
  }


  // process handles the data; each call receives the arguments of a single input row
  override def process(objects: Array[AnyRef]): Unit = {

    // Split the input string on whitespace into individual words
    val strings: Array[String] = objects(0).toString.split("\\s+")
    for (str <- strings) {
      val tmp: Array[String] = new Array[String](1)
      tmp(0) = str
      // forward must be given an array with one element per output column, even if there is only one
      forward(tmp)
    }
  }

  override def close(): Unit = {

  }
}

package com.yyds.Test

import com.yyds.common.SparkUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}


/**
  * Spark SQL custom UDTF driver.
  *
  * Spark SQL currently has no direct support for UDTFs; one approach is to implement a Hive UDTF
  * and register it as a temporary function.
  *
  * Raw input data (udtf.txt):
  *   01//zs//Hadoop scala spark hive hbase
  *   02//ls//Hadoop scala kafka hive hbase Oozie
  *   03//ww//Hadoop scala spark hive sqoop
  */
object _02_SparkSql_UDTF {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    // Create the Spark session; note that Hive support is required here
    val spark: SparkSession = SparkUtils.sparkSessionWithHive(this.getClass.getSimpleName)

    import spark.implicits._
    val sc: SparkContext = spark.sparkContext
    val lines: RDD[String] = sc.textFile("file:\\D:\\works\\spark-sql\\data\\udtf.txt")

    val stuDf: DataFrame = lines.map(_.split("//")).filter(x=>x(1).equals("ls")).map(x=>(x(0),x(1),x(2))).toDF("id","name","class")

    stuDf.printSchema()

    stuDf.createOrReplaceTempView("student")

    // Register the UDTF through Hive's CREATE TEMPORARY FUNCTION syntax
    spark.sql("CREATE TEMPORARY FUNCTION MyUDTF AS 'com.yyds.Test._02_MyUdtf'")
    val resultDF: DataFrame = spark.sql("select MyUDTF(class) from student")

    resultDF.show(50)
  }
}


Output:
+------+
|  type|
+------+
|Hadoop|
| scala|
| kafka|
|  hive|
| hbase|
| Oozie|
+------+
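For this particular use case, a comparable result can also be obtained without a Hive UDTF by combining Spark's built-in split and explode functions. A minimal sketch, reusing the stuDf DataFrame from the driver above:

import org.apache.spark.sql.functions.{col, explode, split}

// Built-in alternative to the UDTF: split the "class" column on whitespace and explode it into rows
stuDf.select(explode(split(col("class"), "\\s+")).as("type")).show(50)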
