User-defined functions come in three kinds:
UDF: takes one row and returns one result (one-to-one). For example, a function that takes an IP address and returns the corresponding province.
UDAF: takes multiple rows and returns one row (aggregate). count and sum are Spark's built-in aggregate functions, but for more complex business logic you define your own.
UDTF: takes one row and returns multiple rows (one-to-many, a Hive concept). Spark SQL has no UDTF of its own; in Spark the same effect is achieved with flatMap.
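A minimal sketch of the flatMap route (my own illustration, not part of the original example; it assumes an existing SparkSession named spark, and the sample data is made up):

import spark.implicits._
val ds = Seq("Hadoop scala spark", "hive hbase").toDS()
// flatMap emits one output row per element, which is exactly the one-to-many (UDTF) pattern
ds.flatMap(line => line.split(" ")).show()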
1. Define the function
2. Register the function
SparkSession.udf.register(): only takes effect inside sql()
functions.udf(): works across the DataFrame API (see the sketch after these steps)
3. Call the function
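A hedged sketch of the two registration routes (the function to_upper, the view people, and the DataFrame df with a string column name are made-up placeholders; spark is an existing SparkSession):

import org.apache.spark.sql.functions.{col, udf}
// route 1: SparkSession.udf.register — usable inside sql() query strings
spark.udf.register("to_upper", (s: String) => s.toUpperCase)
spark.sql("select to_upper(name) from people").show()
// route 2: functions.udf — wraps the function so it can be used in the DataFrame API
val toUpper = udf((s: String) => s.toUpperCase)
df.select(toUpper(col("name"))).show()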
A custom function example
Requirement: count the number of hobbies per user
Create hobbies.txt with the following content:
hobbies.txt:
alice jogging,Coding,cooking
lina travel,dance
Expected output format:
alice jogging,Coding,cooking 3
lina travel,dance 2
<dependencies>
    <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.12</version>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>org.scala-lang</groupId>
        <artifactId>scala-library</artifactId>
        <version>2.11.8</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>2.1.1</version>
    </dependency>
    <dependency>
        <groupId>log4j</groupId>
        <artifactId>log4j</artifactId>
        <version>1.2.17</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>2.1.1</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-hive_2.11</artifactId>
        <version>2.1.1</version>
    </dependency>
    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <version>5.1.36</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-api</artifactId>
        <version>1.7.21</version>
    </dependency>
</dependencies>
Scala code and output:
import org.apache.spark.sql.SparkSession

// case class holding one record
case class Hobbies(name: String, hobbies: String)

object SparkUDFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SparkUDFDemo").master("local[*]").getOrCreate()
    val sc = spark.sparkContext
    // read the text file as an RDD
    val rdd = sc.textFile("in/hobbies.txt")
    // this implicits import is required, otherwise the RDD cannot be converted to a DataFrame
    import spark.implicits._
    // split each line, wrap it in the case class, then convert to a DataFrame
    val df = rdd.map(x => x.split(" ")).map(x => Hobbies(x(0), x(1))).toDF()
    df.printSchema()
    df.show()
    // create a temporary view named temptable
    df.createOrReplaceTempView("temptable")
    // register the custom function; note that it is an anonymous function
    // the UDF is named hobby_num: it splits on "," and counts the hobbies
    spark.udf.register("hobby_num",
      (v: String) => v.split(",").size
    )
    // use the custom function inside the Spark SQL query
    // hobby_num(hobbies): hobby_num is the function name, hobbies is the column it operates on
    val frame = spark.sql("select name,hobbies,hobby_num(hobbies) as hobby_num from temptable")
    frame.show()
  }
}
// Output
root
 |-- name: string (nullable = true)
 |-- hobbies: string (nullable = true)

+-----+--------------------+
| name|             hobbies|
+-----+--------------------+
|alice|jogging,Coding,co...|
| lina|        travel,dance|
+-----+--------------------+

+-----+--------------------+---------+
| name|             hobbies|hobby_num|
+-----+--------------------+---------+
|alice|jogging,Coding,co...|        3|
| lina|        travel,dance|        2|
+-----+--------------------+---------+
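For comparison, the same counting logic can be applied directly through the DataFrame API with functions.udf instead of spark.udf.register. A sketch, reusing the df built above (hobbyNum is my own placeholder name):

import org.apache.spark.sql.functions.{col, udf}
// wrap the anonymous function so it can be used as a Column expression
val hobbyNum = udf((v: String) => v.split(",").length)
df.withColumn("hobby_num", hobbyNum(col("hobbies"))).show()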
Create user.json with the following content (one JSON object per line, as spark.read.json expects):
{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
Scala code and output:
package sql1118
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// custom UDAF (many rows in, one row out): read the json file and compute the average age
class MyAgeAvgFunction extends UserDefinedAggregateFunction {
  // schema of the input arguments
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
    //StructType(StructField("age",LongType)::Nil)
  }
  // schema of the aggregation buffer (intermediate results)
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
    //StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
  }
  // data type of the final result
  override def dataType: DataType = DoubleType
  // whether the function is deterministic, i.e. the same input always produces the same output
  override def deterministic: Boolean = true
  // initialize the buffer: sum = 0, count = 0
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // feed one new input row into the aggregation
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // merge two aggregation buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum of the ages
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // number of age values
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1).toDouble
  }
}
object SparkUDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SparkUDAFDemo").master("local[*]").getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._
    val df = spark.read.json("in/user.json")
    df.printSchema()
    df.show()
    // create and register the custom UDAF
    val myUdaf = new MyAgeAvgFunction
    spark.udf.register("myAvgAge", myUdaf)
    df.createTempView("userinfo")
    val resultDF = spark.sql("select myAvgAge(age) as avgage from userinfo group by sex")
    resultDF.printSchema()
    resultDF.show()
  }
}
// Output:
root
 |-- age: long (nullable = true)
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- sex: string (nullable = true)

+---+----+----+-----+
|age|  id|name|  sex|
+---+----+----+-----+
| 20|1001| foo|  man|
| 24|1002| bar|  man|
| 18|1003| baz|  man|
| 17|1004|foo1|woman|
| 19|1005|bar2|woman|
| 20|1006|baz3|woman|
+---+----+----+-----+

root
 |-- avgage: double (nullable = true)

+------------------+
|            avgage|
+------------------+
|20.666666666666668|
|18.666666666666668|
+------------------+
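The same UDAF can also be invoked through the DataFrame API, since a UserDefinedAggregateFunction instance can be applied to columns. A sketch, reusing df and myUdaf from the code above:

import org.apache.spark.sql.functions.col
// group by sex and apply the custom aggregate to the age column
df.groupBy("sex").agg(myUdaf(col("age")).as("avgage")).show()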
Create udtf.txt with the following content:
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
Scala code and output:
package sql1118
import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.sql.SparkSession

// custom UDTF (one row in, multiple rows out)
/**
 * initialize validates the arguments:
 * checks the number and type of the parameters
 * and defines the output table structure
 */
class myUDTF extends GenericUDTF {
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException(
        "exactly one argument must be passed in"
      )
    }
    if (argOIs(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException(
        "argument type does not match"
      )
    }
    // define the output table structure
    // array lists holding the output column names and their types
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]()
    // output column name
    fieldNames.add("type")
    // output column data type
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    // combine both parts into the output struct
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }
  // Requirement:
  // input:  Hadoop scala kafka hive hbase Oozie
  // output: type
  //         Hadoop
  //         scala
  //         kafka
  //         hive
  //         hbase
  //         Oozie
  /**
   * process does the actual work.
   * For multiple output columns, store each row's values in an array and pass the array to forward;
   * every call to forward produces one output row.
   */
  override def process(objects: Array[AnyRef]): Unit = {
    // split the input string into words on spaces
    val strings = objects(0).toString.split(" ")
    // note: this prints the array reference, not its contents
    println(strings)
    // emit one output row per word
    for (str <- strings) {
      val tmp = new Array[String](1)
      tmp(0) = str
      forward(tmp)
    }
  }
  // called once processing is finished
  override def close(): Unit = {
  }
}
object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    // .enableHiveSupport() is required here
    val spark = SparkSession.builder().master("local[*]")
      .appName("SparkUDTFDemo").enableHiveSupport().getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._
    val lines = sc.textFile("in/udtf.txt")
    lines.collect.foreach(println)
    val stuDF = lines.map(x => x.split("//"))
      .filter(x => x(1).equals("ls")).map(x => (x(0), x(1), x(2))).toDF("id", "name", "subject")
    stuDF.printSchema()
    stuDF.show()
    stuDF.createOrReplaceTempView("student")
    // a UDTF cannot be registered through spark.udf.register; it has to go through Hive
    // spark.udf.register("myUDTF",new myUDTF)
    // create a temporary function; sql1118 in 'sql1118.myUDTF' is the package the UDTF class lives in
    spark.sql("create temporary function myUDTF as 'sql1118.myUDTF'")
    // use the custom myUDTF function in a Spark SQL query, passing the subject column, and print the courses
    val resultDF = spark.sql("select myUDTF(subject) from student")
    resultDF.printSchema()
    resultDF.show()
  }
}
// Output:
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- subject: string (nullable = true)

+---+----+--------------------+
| id|name|             subject|
+---+----+--------------------+
| 02|  ls|Hadoop scala kafk...|
+---+----+--------------------+

root
 |-- type: string (nullable = true)

[Ljava.lang.String;@788dbfc4
+------+
|  type|
+------+
|Hadoop|
| scala|
| kafka|
|  hive|
| hbase|
| Oozie|
+------+
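If Hive support is not available, the same one-to-many result can be obtained without a UDTF by using the built-in split and explode functions. A sketch, reusing the spark session and the student view registered above:

spark.sql("select explode(split(subject, ' ')) as type from student").show()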