UDAF (User-Defined Aggregate Function): a user-defined aggregate function operates on a collection of rows, adding custom logic on top of Spark's aggregation operations. By contrast, a UDF is wrapped by Spark SQL's Catalyst as an Expression and is ultimately evaluated on each input Row through its eval method, one row at a time; a UDAF performs Aggregation-style operations, grouping the data and processing each group's rows as a batch.
A UDAF maintains an aggregation buffer that stores the intermediate result for each group of input data. It updates this buffer for every input row, and once all input rows have been processed, it returns a result based on the values in the aggregation buffer.
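To make this buffer lifecycle concrete, here is a minimal plain-Scala sketch (not Spark code; the Buffer type and the four stage functions are illustrative stand-ins for what a UDAF does): each partition folds its rows into a local buffer, the per-partition buffers are merged, and a finishing step turns the merged buffer into the final value.

// Minimal sketch of the aggregation-buffer lifecycle, in plain Scala.
// `zero`, `reduce`, `merge`, and `finish` mirror the stages a UDAF goes through.
object BufferLifecycleSketch {
  final case class Buffer(sum: Long, count: Long)           // intermediate state

  val zero: Buffer = Buffer(0L, 0L)                         // initial buffer
  def reduce(b: Buffer, salary: Long): Buffer =             // fold in one input row
    Buffer(b.sum + salary, b.count + 1)
  def merge(b1: Buffer, b2: Buffer): Buffer =               // combine partition buffers
    Buffer(b1.sum + b2.sum, b1.count + b2.count)
  def finish(b: Buffer): Double = b.sum.toDouble / b.count  // produce the final value

  def main(args: Array[String]): Unit = {
    // Two hypothetical partitions of the salary column.
    val partitions = Seq(Seq(3000L, 4500L), Seq(3500L, 4000L))
    val merged = partitions.map(_.foldLeft(zero)(reduce)).reduce(merge)
    println(finish(merged))                                 // prints 3750.0
  }
}

Run over the four salaries used later in this section, the sketch prints 3750.0, matching (3000 + 4500 + 3500 + 4000) / 4.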
Create the project as a Maven project and bring in Spark 2.2.1 and the related JARs by adding the corresponding dependencies to the pom.xml file.
The pom.xml file:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>spark221BookExample</groupId>
  <artifactId>spark221BookExample</artifactId>
  <version>1.0-SNAPSHOT</version>

  <properties>
    <scala.version>2.11.8</scala.version>
    <spark.version>2.2.1</spark.version>
    <jedis.version>2.8.2</jedis.version>
    <!-- property name assumed; 1.2.14 is not referenced elsewhere in this POM -->
    <log4j.version>1.2.14</log4j.version>
    <jetty.version>9.2.5.v20141112</jetty.version>
    <container.version>2.17</container.version>
    <java.version>1.8</java.version>
  </properties>

  <repositories>
    <repository>
      <id>scala-tools.org</id>
      <name>Scala-Tools Maven2 Repository</name>
      <url>http://scala-tools.org/repo-releases</url>
    </repository>
  </repositories>

  <pluginRepositories>
    <pluginRepository>
      <id>scala-tools.org</id>
      <name>Scala-Tools Maven2 Repository</name>
      <url>http://scala-tools.org/repo-releases</url>
    </pluginRepository>
  </pluginRepositories>

  <dependencies>
    <dependency>
      <groupId>javax.ws.rs</groupId>
      <artifactId>javax.ws.rs-api</artifactId>
      <version>2.0</version>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-compiler</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-reflect</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scalap</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>4.4</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.specs</groupId>
      <artifactId>specs</artifactId>
      <version>1.2.5</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-launcher_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-network-shuffle_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-hive_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-catalyst_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming-flume-assembly_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming-flume_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
      <version>14.0.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-graphx_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.scalanlp</groupId>
      <artifactId>breeze_2.11</artifactId>
      <version>0.11.2</version>
      <scope>compile</scope>
      <exclusions>
        <exclusion>
          <groupId>junit</groupId>
          <artifactId>junit</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.apache.commons</groupId>
          <artifactId>commons-math3</artifactId>
        </exclusion>
      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-math3</artifactId>
      <version>3.4.1</version>
      <scope>compile</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib-local_2.11</artifactId>
      <version>${spark.version}</version>
      <type>test-jar</type>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-repl_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.hadoop</groupId>
      <artifactId>hadoop-client</artifactId>
      <version>2.6.0</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <!-- Scala 2.10 artifact as in the original; the other Spark modules use _2.11 -->
      <artifactId>spark-streaming-kafka-0-8_2.10</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>mysql</groupId>
      <artifactId>mysql-connector-java</artifactId>
      <version>5.1.6</version>
    </dependency>
    <dependency>
      <groupId>org.apache.hive</groupId>
      <artifactId>hive-jdbc</artifactId>
      <version>1.2.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.httpcomponents</groupId>
      <artifactId>httpclient</artifactId>
      <version>4.4.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.httpcomponents</groupId>
      <artifactId>httpcore</artifactId>
      <version>4.4.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.hadoop</groupId>
      <artifactId>hadoop-common</artifactId>
      <version>2.6.0</version>
    </dependency>
    <dependency>
      <groupId>org.apache.hadoop</groupId>
      <artifactId>hadoop-hdfs</artifactId>
      <version>2.6.0</version>
    </dependency>
    <dependency>
      <groupId>redis.clients</groupId>
      <artifactId>jedis</artifactId>
      <version>${jedis.version}</version>
    </dependency>
    <dependency>
      <groupId>org.json</groupId>
      <artifactId>json</artifactId>
      <version>20090211</version>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-core</artifactId>
      <version>2.6.3</version>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-databind</artifactId>
      <version>2.6.3</version>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-annotations</artifactId>
      <version>2.6.3</version>
    </dependency>
    <dependency>
      <groupId>com.alibaba</groupId>
      <artifactId>fastjson</artifactId>
      <version>1.1.41</version>
    </dependency>
    <dependency>
      <groupId>fastutil</groupId>
      <artifactId>fastutil</artifactId>
      <version>5.0.9</version>
    </dependency>
    <dependency>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-server</artifactId>
      <version>${jetty.version}</version>
    </dependency>
    <dependency>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-servlet</artifactId>
      <version>${jetty.version}</version>
    </dependency>
    <dependency>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-util</artifactId>
      <version>${jetty.version}</version>
    </dependency>
    <dependency>
      <groupId>org.glassfish.jersey.core</groupId>
      <artifactId>jersey-server</artifactId>
      <version>${container.version}</version>
    </dependency>
    <dependency>
      <groupId>org.glassfish.jersey.containers</groupId>
      <artifactId>jersey-container-servlet-core</artifactId>
      <version>${container.version}</version>
    </dependency>
    <dependency>
      <groupId>org.glassfish.jersey.containers</groupId>
      <artifactId>jersey-container-jetty-http</artifactId>
      <version>${container.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.hadoop</groupId>
      <artifactId>hadoop-mapreduce-client-core</artifactId>
      <version>2.6.0</version>
    </dependency>
    <dependency>
      <groupId>org.antlr</groupId>
      <artifactId>antlr4-runtime</artifactId>
      <version>4.5.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.thrift</groupId>
      <artifactId>libthrift</artifactId>
      <version>0.9.3</version>
    </dependency>
  </dependencies>

  <build>
    <plugins>
      <plugin>
        <artifactId>maven-assembly-plugin</artifactId>
        <configuration>
          <!-- element names assumed for the values "dist" and "true" -->
          <finalName>dist</finalName>
          <appendAssemblyId>true</appendAssemblyId>
          <descriptorRefs>
            <descriptorRef>jar-with-dependencies</descriptorRef>
          </descriptorRefs>
        </configuration>
        <executions>
          <execution>
            <id>make-assembly</id>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
      <plugin>
        <artifactId>maven-compiler-plugin</artifactId>
        <configuration>
          <!-- a single "1.7" in the flattened source; source/target pairing assumed -->
          <source>1.7</source>
          <target>1.7</target>
        </configuration>
      </plugin>
      <plugin>
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
        <version>3.2.2</version>
        <executions>
          <execution>
            <id>scala-compile-first</id>
            <phase>process-resources</phase>
            <goals>
              <goal>compile</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
          <recompileMode>incremental</recompileMode>
          <useZincServer>true</useZincServer>
          <args>
            <arg>-unchecked</arg>
            <arg>-deprecation</arg>
            <arg>-feature</arg>
          </args>
          <jvmArgs>
            <jvmArg>-Xms1024m</jvmArg>
            <jvmArg>-Xmx1024m</jvmArg>
          </jvmArgs>
          <javacArgs>
            <javacArg>-source</javacArg>
            <javacArg>${java.version}</javacArg>
            <javacArg>-target</javacArg>
            <javacArg>${java.version}</javacArg>
            <javacArg>-Xlint:all,-serial,-path</javacArg>
          </javacArgs>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.antlr</groupId>
        <artifactId>antlr4-maven-plugin</artifactId>
        <version>4.3</version>
        <executions>
          <execution>
            <id>antlr</id>
            <goals>
              <goal>antlr4</goal>
            </goals>
            <phase>none</phase>
          </execution>
        </executions>
        <configuration>
          <outputDirectory>src/test/java</outputDirectory>
          <!-- element names for the two boolean flags assumed -->
          <listener>true</listener>
          <visitor>true</visitor>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>
The example UserDefinedTypedAggregation.scala implements a type-safe UDAF: the object MyAverage extends the base class Aggregator and implements its six methods (zero, reduce, merge, finish, bufferEncoder, outputEncoder).
The example reads a file of employee salary records and computes the employees' average salary.
1) The data source file data/sql/employees.json contains one record per employee, with the employee's name and salary. Its content is as follows:
{"name":"Michael","salary":3000}
{"name":"Andy","salary":4500}
{"name":"Justin","salary":3500}
{"name":"Berta","salary":4000}
2) Define the case class Employee, whose fields hold the employee's name and salary.
3) Define the case class Average, whose fields hold the running salary sum and the record count.
4) Define the object MyAverage, which extends Aggregator[Employee, Average, Double]: Employee is the input type of the aggregate function, Average is the type in which intermediate values are accumulated, and Double is the type of the final output. MyAverage overrides the six methods of Aggregator.
5) Build the SparkSession and import Spark's implicit conversions with spark.implicits._, which supply the encoders needed to convert a DataFrame into a Dataset.
6) Read the employee salary file employees.json with spark.read.json and convert it into a Dataset of type Employee.
7) Name the column average_salary with MyAverage.toColumn.name; the resulting averageSalary has type TypedColumn[Employee, Double].
8) Query the average salary with ds.select(averageSalary), which evaluates the averageSalary column.
9) result.show() prints the final average salary.
The source code of UserDefinedTypedAggregation.scala:
package org.apache.spark.examples.sql

// $example on:typed_custom_aggregation$
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
// $example off:typed_custom_aggregation$

object UserDefinedTypedAggregation {

  // $example on:typed_custom_aggregation$
  case class Employee(name: String, salary: Long)
  case class Average(var sum: Long, var count: Long)

  object MyAverage extends Aggregator[Employee, Average, Double] {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    def zero: Average = Average(0L, 0L)
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    def reduce(buffer: Average, employee: Employee): Average = {
      buffer.sum += employee.salary
      buffer.count += 1
      buffer
    }
    // Merge two intermediate values
    def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count += b2.count
      b1
    }
    // Transform the output of the reduction
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // Specifies the Encoder for the intermediate value type
    def bufferEncoder: Encoder[Average] = Encoders.product
    // Specifies the Encoder for the final output value type
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }
  // $example off:typed_custom_aggregation$

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local")  // run locally so the example can be launched from the IDE
      .appName("Spark SQL user-defined Datasets aggregation example")
      .getOrCreate()

    import spark.implicits._

    // $example on:typed_custom_aggregation$
    val ds = spark.read.json("data/sql/employees.json").as[Employee]
    ds.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    // Convert the function to a `TypedColumn` and give it a name
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val result = ds.select(averageSalary)
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    // $example off:typed_custom_aggregation$

    spark.stop()
  }

}
Running UserDefinedTypedAggregation.scala in IDEA produces the following result:
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
18/02/21 14:08:10 INFO SparkContext: Running Spark version 2.2.1
......
18/02/21 14:08:41 INFO DAGScheduler: Job 1 finished: show at UserDefinedTypedAggregation.scala:73, took 0.770808 s
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+
......
18/02/21 14:08:45 INFO DAGScheduler: Job 2 finished: show at UserDefinedTypedAggregation.scala:86, took 1.504912 s
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
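The typed Aggregator is not limited to aggregating the whole Dataset. As a minimal sketch (not part of the original example; the grouping key is chosen only for illustration), the same MyAverage can also compute one average per group via groupByKey:

// Sketch only: apply the typed Aggregator per group instead of over the whole Dataset.
// Grouping by the employee's name is purely illustrative; any key function works.
val perGroup = ds.groupByKey(_.name)
  .agg(MyAverage.toColumn.name("average_salary"))
perGroup.show()  // one (value, average_salary) row per distinct key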
This section turns to the untyped UDAF case: the object MyAverage extends the base class UserDefinedAggregateFunction and implements its eight methods (inputSchema, bufferSchema, dataType, deterministic, initialize, update, merge, evaluate).
The example again reads the employee salary file and computes the average salary. Unlike UserDefinedTypedAggregation.scala, which demonstrates typed Dataset operations with a type-safe UDAF, this example reads employees.json with spark.read.json but does not convert the result with as[Employee], so it works on an untyped DataFrame rather than a strongly typed Dataset.
1) Create the SparkSession.
2) Register the custom UDAF MyAverage under the name myAverage via spark.udf.register.
3) Read the employee salary file employees.json with spark.read.json.
4) Register the DataFrame as the temporary view employees with createOrReplaceTempView, and display the view's data.
5) Invoke the custom function myAverage(salary) in a spark.sql statement to compute the average salary; result.show() prints the final result.
The source code of UserDefinedUntypedAggregation.scala:
package sparksql

// $example on:untyped_custom_aggregation$
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
// $example off:untyped_custom_aggregation$

object UserDefinedUntypedAggregation {

  // $example on:untyped_custom_aggregation$
  object MyAverage extends UserDefinedAggregateFunction {
    // Data types of input arguments of this aggregate function
    def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

    // Data types of values in the aggregation buffer
    def bufferSchema: StructType = {
      StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
    }

    // The data type of the returned value
    def dataType: DataType = DoubleType

    // Whether this function always returns the same output on the identical input
    def deterministic: Boolean = true

    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }

    // Updates the given aggregation buffer `buffer` with new input data from `input`
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }
    }

    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // Calculates the final result
    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
  }
  // $example off:untyped_custom_aggregation$

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local")
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate()

    // $example on:untyped_custom_aggregation$
    // Register the function to access it
    spark.udf.register("myAverage", MyAverage)

    val df = spark.read.json("data/sql/employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    val result = spark.sql("SELECT myAverage(salary) AS average_salary FROM employees")
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    // $example off:untyped_custom_aggregation$

    spark.stop()
  }

}
Running UserDefinedUntypedAggregation.scala in IDEA produces the following result:
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
18/02/21 17:09:21 INFO SparkContext: Running Spark version 2.2.1
......
18/02/21 17:09:41 INFO DAGScheduler: Job 1 finished: show at UserDefinedUntypedAggregation.scala:85, took 0.138434 s
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+
......
18/02/21 17:09:43 INFO DAGScheduler: Job 2 finished: show at UserDefinedUntypedAggregation.scala:96, took 0.693570 s
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
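A registered UDAF is not tied to SQL text either. As a minimal sketch (not part of the original example), the same registered myAverage function can be invoked through the DataFrame API with callUDF:

// Sketch only: call the registered UDAF through the DataFrame API instead of SQL.
import org.apache.spark.sql.functions.{callUDF, col}
val apiResult = df.select(callUDF("myAverage", col("salary")).as("average_salary"))
apiResult.show()  // same result as the SQL query: 3750.0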