Spark SQL Custom Function Examples (UDF, UDAF, UDTF)

    • UDF function categories and descriptions
    • Defining and using a custom UDF
    • Maven dependencies
    • Defining and using a custom UDAF
    • Writing a Hive UDTF

UDF function categories and descriptions

User-defined functions fall into three categories:
UDF: one row in, one result out; one-to-one. For example, a function that takes an IP address and returns the corresponding province.
UDAF: multiple rows in, one row out; aggregate functions such as count and sum ship with Spark, but complex business logic requires defining your own.
UDTF: one row in, multiple rows out (Hive); one-to-many. Spark SQL has no UDTF interface; in Spark the same effect is achieved with flatMap, as shown in the sketch below.
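
As a minimal sketch of that one-to-many effect, the following standalone program uses flatMap on a Dataset (the sample data and object name here are made up for illustration):

import org.apache.spark.sql.SparkSession

object FlatMapAsUdtfDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("FlatMapAsUdtfDemo").master("local[*]").getOrCreate()
    import spark.implicits._

    //one input row per line, each holding a space-separated list of subjects
    val ds = Seq("Hadoop scala spark", "hive hbase").toDS()

    //flatMap turns each input row into zero or more output rows,
    //which is the effect a Hive UDTF produces
    val exploded = ds.flatMap(line => line.split(" "))
    exploded.show()

    spark.stop()
  }
}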

Defining and using a custom UDF

1. Define the function
2. Register the function (the two registration styles are contrasted in the sketch after this list)
SparkSession.udf.register(): only takes effect inside sql() queries
functions.udf(): works throughout the DataFrame API
3. Call the function
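
A minimal sketch contrasting the two registration styles (the object name, column names, and sample data here are made up for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object UdfRegisterStyles {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("UdfRegisterStyles").master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("alice", "jogging,Coding,cooking")).toDF("name", "hobbies")

    //style 1: spark.udf.register makes the function callable by name inside SQL text
    spark.udf.register("hobby_num", (s: String) => s.split(",").length)
    df.createOrReplaceTempView("t")
    spark.sql("select name, hobby_num(hobbies) as hobby_num from t").show()

    //style 2: functions.udf wraps the function for use in the DataFrame API
    val hobbyNum = udf((s: String) => s.split(",").length)
    df.select($"name", hobbyNum($"hobbies").as("hobby_num")).show()

    spark.stop()
  }
}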

Custom UDF example
Requirement: count the number of hobbies for each user
Create a hobbies.txt file with the following content:

hobbies.txt:
alice	jogging,Coding,cooking
lina	travel,dance

Expected output format:

alice	jogging,Coding,cooking	3
lina	travel,dance		2

Maven dependencies

 <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>4.12</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>2.11.8</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>2.1.1</version>
    </dependency>
    <dependency>
      <groupId>log4j</groupId>
      <artifactId>log4j</artifactId>
      <version>1.2.17</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>2.1.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-hive_2.11</artifactId>
      <version>2.1.1</version>
    </dependency>
    <dependency>
      <groupId>mysql</groupId>
      <artifactId>mysql-connector-java</artifactId>
      <version>5.1.36</version>
    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>
      <version>1.7.21</version>
    </dependency>
  </dependencies>

Scala code and results:

import org.apache.spark.sql.SparkSession

//case class for the parsed rows
case class Hobbies(name: String, hobbies: String)

object SparkUDFDemo {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("SparkUDFDemo").master("local[*]").getOrCreate()
    val sc = spark.sparkContext

    //read the text file into an RDD
    val rdd = sc.textFile("in/hobbies.txt")

    //the implicit conversions must be imported manually, otherwise the RDD cannot be converted to a DataFrame
    import spark.implicits._

    //split each line on whitespace (the sample file is tab-separated), load it into the case class, then convert to a DataFrame
    val df = rdd.map(x => x.split("\\s+")).map(x => Hobbies(x(0), x(1))).toDF()

    df.printSchema()
    df.show()

    //create a temporary view named temptable
    df.createOrReplaceTempView("temptable")

    //register the custom function (note that it is an anonymous function)
    //the UDF is named hobby_num: it splits the hobbies string on "," and counts the entries
    spark.udf.register("hobby_num",
      (v: String) => v.split(",").size
    )

    //use the custom function in a Spark SQL query
    //hobby_num(hobbies): hobby_num is the function name, hobbies is the column it operates on
    val frame = spark.sql("select name,hobbies,hobby_num(hobbies) as hobby_num from temptable")
    frame.show()
  }
}


//Output:
root
 |-- name: string (nullable = true)
 |-- hobbies: string (nullable = true)

+-----+--------------------+
| name|             hobbies|
+-----+--------------------+
|alice|jogging,Coding,co...|
| lina|        travel,dance|
+-----+--------------------+

+-----+--------------------+---------+
| name|             hobbies|hobby_num|
+-----+--------------------+---------+
|alice|jogging,Coding,co...|        3|
| lina|        travel,dance|        2|
+-----+--------------------+---------+

Defining and using a custom UDAF

Create a user.json file with the following content (spark.read.json expects one JSON object per line):

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

Scala code and results:

package sql1118

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


//custom UDAF (many rows in, one value out): read a JSON file and compute the average age
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  //schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
    //StructType(StructField("age",LongType)::Nil)
  }

  //schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
    //StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
  }

  //data type of the final result
  override def dataType: DataType = DoubleType

  //whether the function is deterministic, i.e. the same input always produces the same output
  override def deterministic: Boolean = true

  //initialize the buffer: sum = 0, count = 0
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  //fold one new input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  //merge two aggregation buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //sum of the ages
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    //number of ages counted
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1).toDouble
  }
}


object SparkUDAFDemo {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("SparkUDAFDemo").master("local[*]").getOrCreate()
    val sc = spark.sparkContext

    import spark.implicits._

    val df = spark.read.json("in/user.json")
    df.printSchema()
    df.show()

    //create and register the custom UDAF function
    val myUdaf = new MyAgeAvgFunction
    spark.udf.register("myAvgAge", myUdaf)

    df.createTempView("userinfo")
    val resultDF = spark.sql("select myAvgAge(age) as avgage from userinfo group by sex")

    resultDF.printSchema()
    resultDF.show()
  }
}


//Output:
root
 |-- age: long (nullable = true)
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- sex: string (nullable = true)

+---+----+----+-----+
|age|  id|name|  sex|
+---+----+----+-----+
| 20|1001| foo|  man|
| 24|1002| bar|  man|
| 18|1003| baz|  man|
| 17|1004|foo1|woman|
| 19|1005|bar2|woman|
| 20|1006|baz3|woman|
+---+----+----+-----+

root
 |-- avgage: double (nullable = true)

+------------------+
|            avgage|
+------------------+
|20.666666666666668|
|18.666666666666668|
+------------------+
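
Because the query groups by sex but does not project it, the two averages in the output cannot be told apart. Continuing the SparkUDAFDemo main method above, a variant that also selects the grouping key might look like this:

//also project the grouping key so each average is labelled by sex
val labelledDF = spark.sql("select sex, myAvgAge(age) as avgage from userinfo group by sex")
labelledDF.show()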

Writing a Hive UDTF

Create a udtf.txt file with the following content:

01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop

Scala code and results:

package sql1118

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import java.util
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.sql.SparkSession

//custom UDTF: one row in, many rows out
/**
  * initialize validates the incoming arguments:
  * it checks the argument count and type,
  * and defines the output table structure
  */
class myUDTF extends GenericUDTF {

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {

    //exactly one argument is allowed
    if (argOIs.length != 1) {
      throw new UDFArgumentException("exactly one argument must be passed in")
    }

    //the argument must be a primitive type (e.g. a string)
    if (argOIs(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("the argument type does not match")
    }

    //define the output table structure
    //array lists holding the output column names and their types
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]()

    //output column name
    fieldNames.add("type")

    //output column data type
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    //combine the column names and types into the output struct
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  //Requirement:
  //input    Hadoop scala kafka hive hbase Oozie
  //output   type
  //         ------
  //         Hadoop
  //         scala
  //         kafka
  //         hive
  //         hbase
  //         Oozie
  /**
    * process does the actual work.
    * For multiple output columns, put each row's values into an array and pass the array to forward;
    * every call to forward emits one output row.
    */
  override def process(objects: Array[AnyRef]): Unit = {

    //split the string on spaces into an array of words
    val strings = objects(0).toString.split(" ")
    //debug print (note: this prints the array reference, not its contents)
    println(strings)

    //emit one output row per word
    for (str <- strings) {
      val tmp = new Array[String](1)
      tmp(0) = str
      forward(tmp)
    }
  }

  //called once processing is finished
  override def close(): Unit = {}
}

object SparkUDTFDemo {

  def main(args: Array[String]): Unit = {

    //.enableHiveSupport() is required
    val spark = SparkSession.builder().master("local[*]")
      .appName("SparkUDTFDemo").enableHiveSupport().getOrCreate()

    val sc = spark.sparkContext

    import spark.implicits._
    val lines = sc.textFile("in/udtf.txt")
    lines.collect.foreach(println)

    val stuDF = lines.map(x => x.split("//"))
      .filter(x => x(1).equals("ls")).map(x => (x(0), x(1), x(2))).toDF("id", "name", "subject")
    stuDF.printSchema()
    stuDF.show()

    stuDF.createOrReplaceTempView("student")

    //Spark cannot register a UDTF directly; only Hive can
    // spark.udf.register("myUDTF",new myUDTF)

    //create a temporary function; in 'sql1118.myUDTF', sql1118 is the package your UDTF class lives in
    spark.sql("create temporary function myUDTF as 'sql1118.myUDTF'")
    //query with the custom myUDTF function, passing the subject column to print the course information
    val resultDF = spark.sql("select myUDTF(subject) from student")

    resultDF.printSchema()
    resultDF.show()
  }
}


//Output:
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop

root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- subject: string (nullable = true)

+---+----+--------------------+
| id|name|             subject|
+---+----+--------------------+
| 02|  ls|Hadoop scala kafk...|
+---+----+--------------------+

root
 |-- type: string (nullable = true)

[Ljava.lang.String;@788dbfc4
+------+
|  type|
+------+
|Hadoop|
| scala|
| kafka|
|  hive|
| hbase|
| Oozie|
+------+
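
For comparison, the same one-to-many output can be produced without Hive support by using Spark SQL's built-in split and explode functions on the student view created above (a minimal sketch continuing the SparkUDTFDemo main method):

//built-in alternative to the Hive UDTF: explode emits one output row per subject
val builtinDF = spark.sql("select explode(split(subject, ' ')) as type from student")
builtinDF.show()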
