1. JDK 1.8
2. Spark 2.1
scala> val df=Range(0,10).toSeq.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]
scala> df.show
| id|
| 0|
| 1|
| 2|
| 3|
| 4|
| 5|
| 6|
| 7|
| 8|
| 9|
scala> def add100(value:Int):Int = { value + 100 }
add100: (value: Int)Int
scala> spark.udf.register("add100", add100(_:Int))
res1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))
scala> df.selectExpr("id", "add100(id) as new_id").show
| id|new_id|
| 0| 100|
| 1| 101|
| 2| 102|
| 3| 103|
| 4| 104|
| 5| 105|
| 6| 106|
| 7| 107|
| 8| 108|
| 9| 109|
scala> df.select(add100($"id")).show
<console>:28: error: type mismatch;
found : org.apache.spark.sql.ColumnName
required: Int
scala> val add100_func=udf(add100 _)
add100_func: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))
scala> df.select($"id", add100_func($"id").as("new_id")).show
| id|new_id|
| 0| 100|
| 1| 101|
| 2| 102|
| 3| 103|
| 4| 104|
| 5| 105|
| 6| 106|
| 7| 107|
| 8| 108|
| 9| 109|
| 10| 110|
| 11| 111|
| 12| 112|
| 13| 113|
| 成员方法 | 释义 |
|---|---|
| inputSchema: StructType | 函数的输入参数的类型定义 |
| dataType: DataType | 函数的返回值类型定义 |
| bufferSchema: StructType | 内部缓存的结构定义,用于记录临时变量等 |
| deterministic: Boolean | 确定性标志:对于相同的输入,函数每次运行是否都返回相同的结果。通常为true |
| initialize(buffer: MutableAggregationBuffer): Unit | 初始化函数,典型用途是把缓存中的变量清零 |
| update(buffer: MutableAggregationBuffer, input: Row): Unit | 更新函数。同一个partition内的数据逐行调用该函数,更新缓存 |
| merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit | 合并函数。各个partition更新完所有数据后,通过merge函数把各自的缓存合并起来 |
| evaluate(buffer: Row): Any | 最终的求值函数,输出为dataType类型 |
import java.util.ArrayList
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.TaskContext
/**
 * Demo UDAF that sums a Long column and also collects every input element
 * into an array. It prints its buffer at each step, so the per-partition
 * `update` phase and the cross-partition `merge` phase can be watched in the
 * console.
 *
 * Buffer layout (see [[bufferSchema]]):
 *  - index 0, "temp_sum":  running sum of the non-null inputs seen so far
 *  - index 1, "ele_array": the non-null input elements seen so far
 *
 * The final result ([[evaluate]]) is the sum, typed as `LongType`.
 */
class UDAF_Sum extends UserDefinedAggregateFunction {

  /** A single Long input column named "input". */
  override def inputSchema: StructType = StructType(Array(
    StructField("input", LongType)
  ))

  /** Intermediate state: the running sum plus the list of elements seen. */
  override def bufferSchema: StructType = StructType(Array(
    StructField("temp_sum", LongType),
    StructField("ele_array", DataTypes.createArrayType(DataTypes.LongType))
  ))

  /** The aggregate result is the Long sum. */
  override def dataType: DataType = LongType

  /** The same input always yields the same sum. */
  override def deterministic: Boolean = true

  /** Start from a zero sum and an empty element list. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = new ArrayList[Long]()
  }

  /**
   * Fold one input row into this partition's buffer.
   *
   * Null inputs are skipped (as SQL `sum` does). Previously `getLong` would
   * throw a NullPointerException on them.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val partitionId = TaskContext.getPartitionId()
    println(s"------partition $partitionId update begin------")
    println(s"partition $partitionId update input: $input")
    if (!input.isNullAt(0)) {
      val value = input.getLong(0)
      buffer(0) = buffer.getLong(0) + value
      // Append into a fresh copy rather than mutating the list returned by
      // getList, then store the copy back into the buffer.
      val elements = new ArrayList[Long](buffer.getList[Long](1))
      elements.add(value)
      buffer(1) = elements
    }
    println(s"partition $partitionId update output: buffer = $buffer")
    println(s"-----partition $partitionId update end-----------")
  }

  /** Combine another partial buffer into `buffer1`: add the sums, concatenate the lists. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val partitionId = TaskContext.getPartitionId()
    println(s"------partition $partitionId merge begin------")
    println(s"partition $partitionId merge input: buffer1 = $buffer1")
    println(s"partition $partitionId merge input: buffer2 = $buffer2")
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    val elements = new ArrayList[Long](buffer1.getList[Long](1))
    elements.addAll(buffer2.getList[Long](1))
    buffer1(1) = elements
    println(s"partition $partitionId merge output: buffer1 = $buffer1")
    println(s"-----partition $partitionId merge end-----------")
  }

  /** Final result: the accumulated sum stored at buffer index 0. */
  override def evaluate(buffer: Row): Any = {
    println("evaluate: " + buffer)
    buffer.getLong(0)
  }
}
在spark shell里面,可以用:paste命令把整段代码一次性复制进去,我们来运行一下看看结果:
[root@ecs-930c spark-2.1.0-bin-hadoop2.7]# bin/spark-shell --master local[2]
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
19/07/21 16:55:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
19/07/21 16:55:51 WARN ObjectStore: Failed to get database global_temp, returning NoSuchObjectException
Spark context Web UI available at
Spark context available as 'sc' (master = local[2], app id = local-1563699348275).
Spark session available as 'spark'.
Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/___/ .__/\_,_/_/ /_/\_\ version 2.1.0
Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_201)
Type in expressions to have them evaluated.
Type :help for more information.
scala> :paste
// Entering paste mode (ctrl-D to finish)
// Exiting paste mode, now interpreting.
import java.util.ArrayList
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.TaskContext
defined class UDAF_Sum
scala> val udaf_sum = new UDAF_Sum
udaf_sum: UDAF_Sum = UDAF_Sum@2caa9666
scala> val df=Range(0,10).toSeq.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]
//这里能看到是2个partition,partition 0里面包含了0,1,2,3,4,partition 1里面包含了5,6,7,8,9
scala> df.foreachPartition(par => par.foreach(x=>println("partition "+TaskContext.getPartitionId.toString+":"+x)))
partition 0:[0]
partition 1:[5]
partition 0:[1]
partition 1:[6]
partition 0:[2]
partition 1:[7]
partition 0:[3]
partition 1:[8]
partition 0:[4]
partition 1:[9]
scala> df.select(udaf_sum($"id")).show
------partition 0 update begin------
------partition 1 update begin------
partition 0 update input: [0]
partition 1 update input: [5]
partition 1 update output: buffer = [5,WrappedArray(5)]
partition 0 update output: buffer = [0,WrappedArray(0)]
-----partition 1 update end-----------
-----partition 0 update end-----------
------partition 1 update begin------
------partition 0 update begin------
partition 1 update input: [6]
partition 0 update input: [1]
partition 1 update output: buffer = [11,WrappedArray(5, 6)]
-----partition 1 update end-----------
------partition 1 update begin------
partition 1 update input: [7]
partition 0 update output: buffer = [1,WrappedArray(0, 1)]
-----partition 0 update end-----------
partition 1 update output: buffer = [18,WrappedArray(5, 6, 7)]
-----partition 1 update end-----------
------partition 0 update begin------
------partition 1 update begin------
partition 0 update input: [2]
partition 1 update input: [8]
partition 0 update output: buffer = [3,WrappedArray(0, 1, 2)]
partition 1 update output: buffer = [26,WrappedArray(5, 6, 7, 8)]
-----partition 0 update end-----------
-----partition 1 update end-----------
------partition 0 update begin------
partition 0 update input: [3]
------partition 1 update begin------
partition 1 update input: [9]
partition 0 update output: buffer = [6,WrappedArray(0, 1, 2, 3)]
-----partition 0 update end-----------
partition 1 update output: buffer = [35,WrappedArray(5, 6, 7, 8, 9)] <-----到这里为止,partition 1更新完成,总共5条记录
------partition 0 update begin------
-----partition 1 update end-----------
partition 0 update input: [4]
partition 0 update output: buffer = [10,WrappedArray(0, 1, 2, 3, 4)] <-----到这里为止,partition 0更新完成,总共也是5条记录
-----partition 0 update end-----------
------partition 0 merge begin------ <------这里开始进入merge阶段
partition 0 merge input: buffer1 = [0,WrappedArray()]
partition 0 merge input: buffer2 = [10,WrappedArray(0, 1, 2, 3, 4)]
partition 0 merge output: buffer1 = [10,WrappedArray(0, 1, 2, 3, 4)]
-----partition 0 merge end-----------
------partition 0 merge begin------
partition 0 merge input: buffer1 = [10,WrappedArray(0, 1, 2, 3, 4)]
partition 0 merge input: buffer2 = [35,WrappedArray(5, 6, 7, 8, 9)]
partition 0 merge output: buffer1 = [45,WrappedArray(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)] <------ merge完成,总和是45,总共10个元素
-----partition 0 merge end-----------
evaluate: [45,WrappedArray(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)]
| 45|
嗯,这个小系列拖拖拉拉的,总算是完结啦~ 你还想知道啥?给我留言吧