1. JDK 1.8
2. Spark 2.1
不同的业务需要不同的处理函数,所以spark也支持用户自定义函数来做专用的处理。这里的自定义函数分两大类:用户已定义函数(UDF)和用户自定义聚合函数(UDAF)。
用户自定义函数比较简单,写起来就是个普通的scala函数,只不过在spark中使用的时候需要单独注册一下。
直接看例子吧。
scala> val df=Range(0,10).toSeq.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]
scala> df.show
+---+
| id|
+---+
| 0|
| 1|
| 2|
| 3|
| 4|
| 5|
| 6|
| 7|
| 8|
| 9|
+---+
##定义一个函数,对给定的整数列都加100
scala> def add100(value:Int):Int = { value + 100 }
add100: (value: Int)Int
##注册成自定义sql函数
scala> spark.udf.register("add100", add100(_:Int))
res1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(,IntegerType,Some(List(IntegerType)))
##调用上面写的自定义函数add100(value:Int)
scala> df.selectExpr("id", "add100(id) as new_id").show
+---+------+
| id|new_id|
+---+------+
| 0| 100|
| 1| 101|
| 2| 102|
| 3| 103|
| 4| 104|
| 5| 105|
| 6| 106|
| 7| 107|
| 8| 108|
| 9| 109|
+---+------+
要注意的是,用spark.udf.register注册的函数,不能用作dataset的函数使用。需要用udf类重新注册一下。
##直接用的话,会类型不匹配的。
scala> df.select(add100($"id")).show
:28: error: type mismatch;
found : org.apache.spark.sql.ColumnName
required: Int
df.select(add100($"id")).show
##正确用法,用udf注册
scala> val add100_func=udf(add100 _)
add100_func: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(,IntegerType,Some(List(IntegerType)))
现在相当于有了一个add100_func的函数,类型是UserDefinedFunction
scala> df.select($"id", add100_func($"id").as("new_id")).show
+---+------+
| id|new_id|
+---+------+
| 0| 100|
| 1| 101|
| 2| 102|
| 3| 103|
| 4| 104|
| 5| 105|
| 6| 106|
| 7| 107|
| 8| 108|
| 9| 109|
| 10| 110|
| 11| 111|
| 12| 112|
| 13| 113|
+---+------+
好了,UDF就说到这,挺简单的。下面的UDAF比较起来,复杂多了。
和UDF比起来,就多了一个A:聚合,Aggregation。其实聚合函数很常见,平时写SQL,求和啊,求均值啊这些都是。但是,自己写UDAF,比起写UDF可是麻烦多了。想想也是,一般的UDF,就是处理一行数据中的一列或多列,做个变换后返回。而UDAF是针对多行数据来处理的,最后只输出一行结果,操作本来就复杂些。
要实现一个UDAF功能,有两种方式:一种是从UserDefinedAggregateFunction类继承,一种是从Aggregator类继承。这两种方式基本上类似,前者是非类型安全的,但是比较灵活,不需要传入整行数据,只要传需要做聚合的列就可以了。后者是强类型,api看起来友好一些,但是,对于列很多的情况,比较麻烦。我个人比较倾向于使用UserDefinedAggregateFunction类的继承实现。
从UserDefinedAggregateFunction类继承,需要实现8个成员方法。
成员方法 | 释义 |
---|---|
inputSchema: StructType | 函数的输入参数的类型定义 |
dataType: DataType | 函数的返回值类型定义 |
bufferSchema: StructType | 内部缓存,记录临时变量等 |
deterministic: Boolean | 这是一个确定性的指示。就是说,是否给定输入后,每次运行的结果都一致。通常都是true |
initialize(buffer: MutableAggregationBuffer): Unit | 初始化函数。典型的功能就是变量清零之类的 |
update(buffer: MutableAggregationBuffer, input: Row): Unit | 更新函数。在同一个partition内的数据一行一行的被调用到该函数做更新处理 |
merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit | 合并函数。各个partition更新完所有数据后,通过merge函数合并 |
evaluate(buffer: Row): Any | 最终的求值函数,输出为dataType类型 |
基本上就是按照上述的过程来实现功能。定义输入输出,定义中间缓存的数据结构,定义初始化,更新,合并,最后求值。
来看一个实现整型求和函数的代码示例:
import java.util.ArrayList
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.TaskContext
class UDAF_Sum extends UserDefinedAggregateFunction {
//1.定义输入数据的类型
override def inputSchema = StructType(Array(
StructField("input", LongType)
))
//2.定义中间数据的类型
override def bufferSchema = StructType(Array(
//temp_sum很明显是保存部分和
StructField("temp_sum", LongType),
//ele_array这里是用来记录当前处理了哪些元素,用来帮助观察整个计算过程。
StructField("ele_array", DataTypes.createArrayType(DataTypes.LongType))
))
//3.定义返回结果的类型
override def dataType: DataType = LongType
//4.输出的确定性指示,一般都是true
override def deterministic = true
//5.定义初始化函数,就是些初始值的处理。
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//初始化,因为是求和的,所以和的初值显然为0
buffer(0) = 0L
//记录当前已处理的所有输入的数
buffer(1) = new ArrayList[Long]()
}
//6.定义update函数,对于一个partition来说,里面的每条数据都会经过update
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val par_id = TaskContext.getPartitionId()
println(s"------partition $par_id update begin------")
println(s"partition $par_id update input: $input")
buffer(0) = buffer.getLong(0) + input.getLong(0)
val tmpList = new ArrayList(buffer.getList[Long](1))
tmpList.add(input.getLong(0))
buffer(1) = tmpList
println(s"partition $par_id update output: buffer = $buffer")
println(s"-----partition $par_id update end-----------")
}
//7.定义merge函数,处理所有partition的全局聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//每个分区计算的结果进行相加
val par_id = TaskContext.getPartitionId()
println(s"------partition $par_id merge begin------")
println(s"partition $par_id merge input: buffer1 = $buffer1")
println(s"partition $par_id merge input: buffer2 = $buffer2")
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
val tmpList = new ArrayList(buffer1.getList[Long](1))
tmpList.addAll( buffer2.getList[Long](1))
buffer1(1) = tmpList
println(s"partition $par_id merge output: buffer1 = $buffer1")
println(s"-----partition $par_id merge end-----------")
}
//8.定义evaluate函数,返回最终的结果
override def evaluate(buffer: Row): Any = {
println("evaluate: " + buffer)
buffer.getLong(0)
}
}
在spark shell里面,可以用:paste命令把整段代码一次性复制进去,我们来运行一下看看结果:
[root@ecs-930c spark-2.1.0-bin-hadoop2.7]# bin/spark-shell --master local[2]
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
19/07/21 16:55:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
19/07/21 16:55:51 WARN ObjectStore: Failed to get database global_temp, returning NoSuchObjectException
Spark context Web UI available at http://192.168.1.153:4040
Spark context available as 'sc' (master = local[2], app id = local-1563699348275).
Spark session available as 'spark'.
Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/___/ .__/\_,_/_/ /_/\_\ version 2.1.0
/_/
Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_201)
Type in expressions to have them evaluated.
Type :help for more information.
scala> :paste
// Entering paste mode (ctrl-D to finish)
///////////////////////////////////////////////
上面的代码直接粘贴,就不重复了,粘贴后按Ctrl-D结束
///////////////////////////////////////////////
// Exiting paste mode, now interpreting.
import java.util.ArrayList
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.TaskContext
defined class UDAF_Sum
//实例化一个函数对象出来
scala> val udaf_sum = new UDAF_Sum
udaf_sum: UDAF_Sum = UDAF_Sum@2caa9666
//生成测试数据的dataset,数字0到10,字段名"id"
scala> val df=Range(0,10).toSeq.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]
//这里我把每个数字在哪个partition打印出来了。
//这里能看到是2个partition,partition 0里面包含了0,1,2,3,4,partition 1里面包含了5,6,7,8,9
scala> df.foreachPartition(par => par.foreach(x=>println("partition "+TaskContext.getPartitionId.toString+":"+x)))
partition 0:[0]
partition 1:[5]
partition 0:[1]
partition 1:[6]
partition 0:[2]
partition 1:[7]
partition 0:[3]
partition 1:[8]
partition 0:[4]
partition 1:[9]
//现在来调用我们创建的UDAF函数,注册的名字是udaf_sum,传入的列是id
scala> df.select(udaf_sum($"id")).show
------partition 0 update begin------
------partition 1 update begin------
partition 0 update input: [0]
partition 1 update input: [5]
partition 1 update output: buffer = [5,WrappedArray(5)]
partition 0 update output: buffer = [0,WrappedArray(0)]
-----partition 1 update end-----------
-----partition 0 update end-----------
------partition 1 update begin------
------partition 0 update begin------
partition 1 update input: [6]
partition 0 update input: [1]
partition 1 update output: buffer = [11,WrappedArray(5, 6)]
-----partition 1 update end-----------
------partition 1 update begin------
partition 1 update input: [7]
partition 0 update output: buffer = [1,WrappedArray(0, 1)]
-----partition 0 update end-----------
partition 1 update output: buffer = [18,WrappedArray(5, 6, 7)]
-----partition 1 update end-----------
------partition 0 update begin------
------partition 1 update begin------
partition 0 update input: [2]
partition 1 update input: [8]
partition 0 update output: buffer = [3,WrappedArray(0, 1, 2)]
partition 1 update output: buffer = [26,WrappedArray(5, 6, 7, 8)]
-----partition 0 update end-----------
-----partition 1 update end-----------
------partition 0 update begin------
partition 0 update input: [3]
------partition 1 update begin------
partition 1 update input: [9]
partition 0 update output: buffer = [6,WrappedArray(0, 1, 2, 3)]
-----partition 0 update end-----------
partition 1 update output: buffer = [35,WrappedArray(5, 6, 7, 8, 9)] <-----到这里为止,partition 1更新完成,总共5条记录
------partition 0 update begin------
-----partition 1 update end-----------
partition 0 update input: [4]
partition 0 update output: buffer = [10,WrappedArray(0, 1, 2, 3, 4)] <-----到这里为止,partition 0更新完成,总共也是5条记录
-----partition 0 update end-----------
------partition 0 merge begin------ <------这里开始进入merge阶段
partition 0 merge input: buffer1 = [0,WrappedArray()]
partition 0 merge input: buffer2 = [10,WrappedArray(0, 1, 2, 3, 4)]
partition 0 merge output: buffer1 = [10,WrappedArray(0, 1, 2, 3, 4)]
-----partition 0 merge end-----------
------partition 0 merge begin------
partition 0 merge input: buffer1 = [10,WrappedArray(0, 1, 2, 3, 4)]
partition 0 merge input: buffer2 = [35,WrappedArray(5, 6, 7, 8, 9)]
partition 0 merge output: buffer1 = [45,WrappedArray(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)] <------ merge完成,总和是45,总共10个元素
-----partition 0 merge end-----------
evaluate: [45,WrappedArray(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)]
+------------+
|udaf_sum(id)|
+------------+
| 45|
+------------+
配个图看清楚一点:
嗯,这个小系列拖拖拉拉的,总算是完结啦~ 你还想知道啥?给我留言吧