Spark Custom Functions (UDF, UDAF, UDTF)

Contents

  • 1. Custom Standard Function (UDF)
  • 2. Custom Aggregate Function (UDAF)
  • 3. Custom Table-Generating Function (UDTF)

Spark ships with a large set of built-in functions, and it also lets developers define their own.

Steps for defining a custom function in Spark:
1. Define the function.
2. Register the function (a minimal sketch contrasting the two registration paths follows this list):
   SparkSession.udf.register(): the registered name can only be used inside sql() statements.
   functions.udf(): the returned function can also be used with the DataFrame API.
3. Call the function.
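
A minimal sketch contrasting the two registration paths (assuming an existing SparkSession named spark and a DataFrame df with a string column hobbies; the names count_items and countItems are illustrative and not part of the examples below):

import org.apache.spark.sql.functions.{col, udf}

// Path 1: spark.udf.register -- the name "count_items" then becomes usable inside spark.sql(...)
spark.udf.register("count_items", (s: String) => s.split(",").length)

// Path 2: functions.udf -- the returned function is usable directly in the DataFrame API
val countItems = udf((s: String) => s.split(",").length)
val withCount = df.withColumn("item_count", countItems(col("hobbies")))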

1. Custom Standard Function (UDF)

The directory D:\test\t\ contains a file hobbies.txt with the following content:

alice	jogging,Coding,cooking
lina	travel,dance

Requirement: count the number of hobbies per user.
Expected output format:

alice	jogging,Coding,cooking	3
lina	travel,dance			2
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkUDFDemo {

  // case class describing one parsed line of hobbies.txt
  case class Hobbies(name: String, hobbies: String)

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local[1]")
      .appName("SparkUDFDemo")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext

    // import the implicit conversions manually, otherwise the RDD cannot be converted to a DataFrame
    import spark.implicits._

    val rdd: RDD[String] = sc.textFile("D:\\test\\t\\hobbies.txt")
    val df: DataFrame = rdd.map(x => x.split("\t")).map(x => Hobbies(x(0), x(1))).toDF()

    //df.printSchema()
    //df.show()

    df.createOrReplaceTempView("hobbies")
    // register the custom UDF; note that it is an anonymous function
    spark.udf.register("hoby_num", (s: String) => s.split(",").size)

    val frame: DataFrame = spark.sql("select name, hobbies, hoby_num(hobbies) as hobnum from hobbies")
    frame.show()
  }
}

Output:

+-----+--------------------+------+
| name|             hobbies|hobnum|
+-----+--------------------+------+
|alice|jogging,Coding,co...|     3|
| lina|        travel,dance|     2|
+-----+--------------------+------+

2. Custom Aggregate Function (UDAF)

UDAF stands for User Defined Aggregate Function. What distinguishes an aggregate function from an ordinary function? An ordinary function takes one row of input and produces one output, whereas an aggregate function takes a group of rows (usually more than one) as input and produces a single output, i.e. it combines the values of a group into one result.
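
A quick illustration of the difference, assuming the userinfo view and SparkSession created in the example below (upper() and avg() are built-in functions used only for contrast):

// ordinary function: one output row per input row
spark.sql("select name, upper(sex) from userinfo").show()
// aggregate function: one output row per group
spark.sql("select sex, avg(age) from userinfo group by sex").show()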

Using a UDAF:

Extend UserDefinedAggregateFunction.

Steps for using UserDefinedAggregateFunction:

  1. Define a class that extends UserDefinedAggregateFunction and implement the method for each stage.

  2. Register the UDAF with Spark and bind a name to it.

  3. Call it in SQL statements using the bound name.

The directory D:\test\t\ contains a file user.json with the following content:

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

Requirement: compute the average age (grouped by sex).

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object SparkUDAFDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SparkUDAFDemo")
      .getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext
    val df = spark.read.json("D:\\test\\t\\user.json")

    // create and register the custom UDAF
    val myUdaf = new MyAgeAvgFunction
    spark.udf.register("myAvgAge", myUdaf)

    df.createTempView("userinfo")
    val resultDF = spark.sql("select myAvgAge(age) as avg_age from userinfo group by sex")
    resultDF.printSchema()
    resultDF.show()
  }
}

class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  // input schema of the aggregate function
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
    // alternative way to write it
    //StructType(StructField("age", LongType) :: Nil)
  }

  // schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
    // alternative way to write it
    //StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // data type of the aggregate function's return value
  override def dataType: DataType = DoubleType

  // whether the function is deterministic, i.e. the same input always produces the same output
  override def deterministic: Boolean = true

  // initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // process one new input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // merge two aggregation buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // total sum of ages
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // number of rows
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

Result:

root
 |-- avg_age: double (nullable = true)

+------------------+
|           avg_age|
+------------------+
|20.666666666666668|
|18.666666666666668|
+------------------+

Reference: https://www.cnblogs.com/cc11001100/p/9471859.html
(that post also describes another approach: extending Aggregator)
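
A minimal sketch of that Aggregator approach for the same average-age calculation (assuming Spark 3.0+, where functions.udaf() can wrap an Aggregator for use in SQL; the names AvgBuffer and MyAgeAvgAggregator are illustrative):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// buffer holding the running sum and count
case class AvgBuffer(var sum: Long, var count: Long)

object MyAgeAvgAggregator extends Aggregator[Long, AvgBuffer, Double] {
  // empty buffer for a new group
  override def zero: AvgBuffer = AvgBuffer(0L, 0L)
  // fold one input value into the buffer
  override def reduce(b: AvgBuffer, age: Long): AvgBuffer = { b.sum += age; b.count += 1; b }
  // merge two partial buffers
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  // compute the final result
  override def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// registration for use in SQL, e.g.:
// spark.udf.register("myAvgAge2", functions.udaf(MyAgeAvgAggregator))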

3. Custom Table-Generating Function (UDTF)

The directory D:\test\t\ contains a file udtf.txt with the following content:

01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop

Requirement: for user ls, turn "Hadoop scala kafka hive hbase Oozie" into the following form:

type      -- (header)
Hadoop
scala
kafka
hive
hbase
Oozie
import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.spark.sql.SparkSession

object SparkUDTFDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .enableHiveSupport()    // Hive support is required for CREATE TEMPORARY FUNCTION
      .appName("SparkUDTFDemo")
      .getOrCreate()
    val sc = spark.sparkContext

    import spark.implicits._

    val lines = sc.textFile("D:\\test\\t\\udtf.txt")
    val stuDF = lines.map(_.split("//")).filter(x => x(1).equals("ls"))
      .map(x => (x(0), x(1), x(2))).toDF("id", "name", "class")
    //stuDF.printSchema()
    //stuDF.show()

    stuDF.createTempView("student")

    // note: if the class after AS lives in a package (here kb09.sql), the fully-qualified name must be used!
    spark.sql("CREATE TEMPORARY FUNCTION myUDTF AS 'kb09.sql.myUDTF'")
    val resultDF = spark.sql("select myUDTF(class) from student")

    resultDF.show()
  }
}

class myUDTF extends GenericUDTF {

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("exactly one argument is required")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("argument type does not match: a primitive type is required")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]

    // name of the output column
    fieldNames.add("type")
    // type of the output column
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  // input: "Hadoop scala kafka hive hbase Oozie"
  override def process(objects: Array[AnyRef]): Unit = {
    // split the input string on spaces into an array of words
    val strings = objects(0).toString.split(" ")
    println(strings)   // prints the array reference, e.g. [Ljava.lang.String;@...
    for (str <- strings) {
      // each word becomes one output row
      val tmp = new Array[String](1)
      tmp(0) = str
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}

Output:

[Ljava.lang.String;@6d0e1408
+------+
|  type|
+------+
|Hadoop|
| scala|
| kafka|
|  hive|
| hbase|
| Oozie|
+------+
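
To keep the other columns of student alongside the generated rows, the UDTF can also be invoked through LATERAL VIEW; a minimal sketch under the same setup (the aliases t and type are illustrative):

// join each generated row back to its source row via LATERAL VIEW
spark.sql(
  """select name, t.type
    |from student
    |lateral view myUDTF(class) t as type""".stripMargin).show()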
