Spark(28) -- SparkSQL自定义函数(UDF、UDAF、UDTF)

类似于hive当中的自定义函数,我们在spark当中,如果内置函数不够我们使用,我们同样可以使用自定义函数来实现我们的功能,spark当中的自定义函数,同样的也有

  • UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
  • UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
  • UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap

无论Hive还是SparkSQL分析处理数据时,往往需要使用函数,SparkSQL模块本身自带很多实现公共功能的函数,在org.apache.spark.sql.functions中。SparkSQL与Hive一样支持定义函数:UDF和UDAF,尤其是UDF函数在实际项目中使用最为广泛。
回顾Hive中自定义函数有三种类型:

  • 第一种:UDF(User-Defined-Function) 函数
    • 一对一的关系,输入一个值经过函数以后输出一个值;
    • 在Hive中继承UDF类,方法名称为evaluate,返回值不能为void,其实就是实现一个方法;
  • 第二种:UDAF(User-Defined Aggregation Function) 聚合函数
    • 多对一的关系,输入多个值输出一个值,通常与groupBy联合使用;
  • 第三种:UDTF(User-Defined Table-Generating Functions) 函数
    • 一对多的关系,输入一个值输出多个值(一行变为多行);
    • 用户自定义生成函数,有点像flatMap;

目前来说Spark 框架各个版本及各种语言对自定义函数的支持:
Spark(28) -- SparkSQL自定义函数(UDF、UDAF、UDTF)_第1张图片
在SparkSQL中,目前仅仅支持UDF函数和UDAF函数:

  • UDF函数:一对一关系;
  • UDAF函数:聚合函数,通常与group by 分组函数连用,多对一关系;

由于SparkSQL数据分析有两种方式:DSL编程和SQL编程,所以定义UDF函数也有两种方式,不同方式可以在不同分析中使用。

1. UDF实战[掌握]

使用SparkSession中udf方法定义和注册函数,在SQL中使用,使用如下方式定义:
Spark(28) -- SparkSQL自定义函数(UDF、UDAF、UDTF)_第2张图片

范例演示: 将姓名转换为小写,调用String中toLowerCase方法。有数据格式如下:

helloworld
abc
study
smallWORD

将每一行数据转换成大写

代码实现如下:


import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{
     DataFrame, SQLContext, SparkSession}

case class Small(line:String)

object SparkFunction {
     
  def main(args: Array[String]): Unit = {
     
    //获取sparkSession
    val sparkSession: SparkSession = SparkSession.builder().appName("sparkFunction").master("local[2]").getOrCreate()
  //通过sparkSession得到sparkContext
    val sparkContext: SparkContext = sparkSession.sparkContext
    //读取文件内容,获取RDD
    val fileRdd: RDD[String] = sparkContext.textFile("file:///........\\udf.txt")
    //配合我们的样例类,将我们的每一行文件内容转换成一个样例类
    val smallRdd: RDD[Small] = fileRdd.map( x => Small(x))
    //导入spark的隐式转换
    import sparkSession.implicits._
    //将我们的RDD转换成DataFrame
    val smallDF: DataFrame = smallRdd.toDF()
    //df注册成为一张临时表
    smallDF.createOrReplaceTempView("small_table")
    //通过sparkSession进行UDF的注册,将我们的小写转换成大写
    sparkSession.udf.register("smallToBigger", new UDF1[String,String]() {
     
      @throws[Exception]
      override def call(t1: String): String = {
     
        t1.toUpperCase()
      }
    }, DataTypes.StringType)
    //使用UDF函数
    sparkSession.sql("select line, smallToBigger(line) as biggerLine from small_table").show()
    sparkSession.stop()
  }
}

Lambda表达式实现:

sparkSession.udf.register("smaller1", (x: String) => (x.toUpperCase()))
sparkSession.sql("select line,smaller1(line) from small_table ").show()
smallRdd.select($"line", callUDF("smaller1", $"line")).show()

2. UDAF实战

需求:现有json格式数据内容如下

{
     "name":"Michael","salary":3000}
{
     "name":"Andy","salary":4500}
{
     "name":"Justin","salary":3500}
{
     "name":"Berta","salary":4000}

求取平均工资
第一个变量,放工资累加和
第二个变量,放一共有多少条数据

代码实现如下:

import org.apache.spark.sql.{
     DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{
     MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class SparkFunctionUDAF extends UserDefinedAggregateFunction{
     
  //输入的数据类型的schema
  override def inputSchema: StructType = {
     
     StructType(StructField("input",LongType)::Nil)
  }
  //缓冲区数据类型schema,说白了就是转换之后的数据的schema
//求解平均值“总金额,总人数”
  override def bufferSchema: StructType = {
     
    StructType(StructField("sum",LongType)::StructField("total",LongType)::Nil)
  }
  //返回值的数据类型---总金额/总人数 可能会有小数,返回doubleType
  override def dataType: DataType = {
     
    DoubleType
  }
  //确定是否相同的输入会有相同的输出
  override def deterministic: Boolean = {
     
    true
  }
  //初始化内部数据结构
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
     
    buffer(0) = 0L  //总金额都累加到Buffer0中,初始值为0
    buffer(1) = 0L  //总人数都累加到buffer1参数中
  }
  //更新数据内部结构
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
     
    //所有的金额相加
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    //一共有多少条数据
    buffer(1) = buffer.getLong(1) + 1 //能否累加10?--不能,+!为了求解人数
  }
  //来自不同分区的数据进行合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
     
    buffer1(0) =buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  //计算输出数据值
  override def evaluate(buffer: Row): Any = {
     
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

object SparkFunctionUDAF {
     
  def main(args: Array[String]): Unit = {
     
    //获取sparkSession
    val sparkSession: SparkSession = SparkSession.builder().appName("sparkUDAF").master("local[2]").getOrCreate()
    //通过sparkSession读取json文件得到DataFrame
    val employeeDF: DataFrame = sparkSession.read.json("file:///F:\\scala与spark课件资料教案\\spark课程\\3、spark第三天\\资料\\udaf.txt")
    //通过DataFrame创建临时表
    employeeDF.createOrReplaceTempView("employee_table")
    //注册我们的自定义UDAF函数
    sparkSession.udf.register("avgSal",new SparkFunctionUDAF)
    //调用我们的自定义UDAF函数
    sparkSession.sql("select avgSal(salary) from employee_table").show()

    sparkSession.close()

  }
}

你可能感兴趣的:(Spark,spark)