用户自定义udf
自定义udf的方式有两种
- SQLContext.udf.register()
- 创建UserDefinedFunction
这两种个方式 使用范围不一样
package com.test.spark
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Dataset, Row, SparkSession}
/**
* @author Administrator
* 2019/7/22-14:04
*
*/
object TestUdf {
val spark = SparkSession
.builder()
.appName("TestCreateDataset")
.config("spark.some.config.option", "some-value")
.master("local")
.enableHiveSupport()
.getOrCreate()
val sQLContext = spark.sqlContext
import spark.implicits._
def main(args: Array[String]): Unit = {
testudf
}
def testudf() = {
val iptoLong: UserDefinedFunction = getIpToLong()
val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
ds.createOrReplaceTempView("table1")
sQLContext.udf.register("addName", sqlUdf(_: String)) //addName 只能在SQL里面用 不能在DSL 里面用
//1.SQL
sQLContext.sql("select *,addName(name) as nameAddName from table1")
.show()
//2.DSL
val addName: UserDefinedFunction = udf((str: String) => ("ip: " + str))
ds.select($"*", addName($"ip").as("ipAddName"))
.show()
//如果自定义函数相对复杂,可以将它分离出去 如iptoLong
ds.select($"*", iptoLong($"ip").as("iptoLong"))
.show()
}
def sqlUdf(name: String): String = {
"name:" + name
}
/**
* 用户自定义 UDF 函数
*
* @return
*/
def getIpToLong(): UserDefinedFunction = {
val ipToLong: UserDefinedFunction = udf((ip: String) => {
val arr: Array[String] = ip.replace(" ", "").replace("\"", "").split("\\.")
var result: Long = 0
var ipl: Long = 0
if (arr.length == 4) {
for (i <- 0 to 3) {
ipl = arr(i).toLong
result |= ipl << ((3 - i) << 3)
}
} else {
result = -1
}
result
})
ipToLong
}
}
输出结果
+---+---------------+---------+--------------+
|age| ip| name| nameAddName|
+---+---------------+---------+--------------+
| 24| 192.168.0.8| lillcol| name:lillcol|
|100| 192.168.255.1| adson| name:adson|
| 39| 192.143.255.1| wuli| name:wuli|
| 20| 192.168.255.1| gu| name:gu|
| 15| 243.168.255.9| ason| name:ason|
| 1| 108.168.255.1| tianba| name:tianba|
| 25|222.168.255.110|clearlove|name:clearlove|
| 30|222.168.255.110|clearlove|name:clearlove|
+---+---------------+---------+--------------+
+---+---------------+---------+-------------------+
|age| ip| name| ipAddName|
+---+---------------+---------+-------------------+
| 24| 192.168.0.8| lillcol| ip: 192.168.0.8|
|100| 192.168.255.1| adson| ip: 192.168.255.1|
| 39| 192.143.255.1| wuli| ip: 192.143.255.1|
| 20| 192.168.255.1| gu| ip: 192.168.255.1|
| 15| 243.168.255.9| ason| ip: 243.168.255.9|
| 1| 108.168.255.1| tianba| ip: 108.168.255.1|
| 25|222.168.255.110|clearlove|ip: 222.168.255.110|
| 30|222.168.255.110|clearlove|ip: 222.168.255.110|
+---+---------------+---------+-------------------+
+---+---------------+---------+----------+
|age| ip| name| iptoLong|
+---+---------------+---------+----------+
| 24| 192.168.0.8| lillcol|3232235528|
|100| 192.168.255.1| adson|3232300801|
| 39| 192.143.255.1| wuli|3230662401|
| 20| 192.168.255.1| gu|3232300801|
| 15| 243.168.255.9| ason|4087938825|
| 1| 108.168.255.1| tianba|1823014657|
| 25|222.168.255.110|clearlove|3735617390|
| 30|222.168.255.110|clearlove|3735617390|
+---+---------------+---------+----------+
用户自定义 UDAF 函数(即聚合函数)
弱类型用户自定义聚合函数
通过继承UserDefinedAggregateFunction
package com.test.spark
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}
/**
* @author lillcol
* 2019/7/22-15:09
* 弱类型用户自定义聚合函数
*/
object TestUDAF extends UserDefinedAggregateFunction {
// 聚合函数输入参数的数据类型
// :: 用于的是向队列的头部追加数据,产生新的列表,Nil 是一个空的 List,定义为 List[Nothing]
override def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)
//等效于
// override def inputSchema: StructType=new StructType() .add("age", IntegerType).add("name", StringType)
// 聚合缓冲区中值的数据类型
override def bufferSchema: StructType = {
StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
}
// UserDefinedAggregateFunction返回值的数据类型。
override def dataType: DataType = DoubleType
// 如果这个函数是确定的,即给定相同的输入,总是返回相同的输出。
override def deterministic: Boolean = true
// 初始化给定的聚合缓冲区,即聚合缓冲区的零值。
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// sum, 总的年龄
buffer(0) = 0
// count, 人数
buffer(1) = 0
}
// 使用来自输入的新输入数据更新给定的聚合缓冲区。
// 每个输入行调用一次。(同一分区)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getInt(0) + input.getInt(0) //年龄 叠加
buffer(1) = buffer.getInt(1) + 1 //人数叠加
}
// 合并两个聚合缓冲区并将更新后的缓冲区值存储回buffer1。
// 当我们将两个部分聚合的数据合并在一起时,就会调用这个函数。(多个分区)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0) //年龄 叠加
buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1) //人数叠加
}
override def evaluate(buffer: Row): Any = {
buffer.getInt(0).toDouble / buffer.getInt(1)
}
val spark = SparkSession
.builder()
.appName("Spark SQL basic example")
// .config("spark.some.config.option", "some-value")
.master("local[*]") // 本地测试
.getOrCreate()
import spark.implicits._
def main(args: Array[String]): Unit = {
spark.udf.register("myAvg", TestUDAF)
val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
ds.createOrReplaceTempView("table1")
//SQL
spark.sql("select myAvg(age) as avgAge from table1")
.show()
//DSL
val myavg = TestUDAF
ds.select(TestUDAF($"age").as("avgAge"))
.show()
}
}
输出结果:
+------+
|avgAge|
+------+
| 31.75|
+------+
+------+
|avgAge|
+------+
| 31.75|
+------+
强类型用户自定义聚合函数
通过继承Aggregator(是org.apache.spark.sql.expressions 下的 不要引错包了)
package com.test.spark
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions._
/**
* @author Administrator
* 2019/7/22-16:07
*
*/
// 既然是强类型,可能有 case 类
case class Person(name: String, age: Double, ip: String)
case class Average(var sum: Double, var count: Double)
object MyAverage extends Aggregator[Person, Average, Double] {
// 此聚合的值为零。应该满足任意b + 0 = b的性质。
// 定义一个数据结构,保存工资总数和工资总个数,初始都为0
override def zero: Average = {
Average(0, 0)
}
// 将两个值组合起来生成一个新值。为了提高性能,函数可以修改b并返回它,而不是为b构造新的对象。
// 相同 Execute 间的数据合并(同一分区)
override def reduce(b: Average, a: Person): Average = {
b.sum += a.age
b.count += 1
b
}
// 合并两个中间值。
// 聚合不同 Execute 的结果(不同分区)
override def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
// 计算最终结果
override def finish(reduction: Average): Double = {
reduction.sum.toInt / reduction.count
}
// 为中间值类型指定“编码器”。
override def bufferEncoder: Encoder[Average] = Encoders.product
// 为最终输出值类型指定“编码器”。
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
val spark = SparkSession
.builder()
.appName("Spark SQL basic example")
// .config("spark.some.config.option", "some-value")
.master("local[*]") // 本地测试
.getOrCreate()
import spark.implicits._
def main(args: Array[String]): Unit = {
val ds: Dataset[Person] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson").as[Person]
ds.show()
val avgAge = MyAverage.toColumn/*.name("avgAge")*///指定该列的别名为avgAge
ds.select(avgAge)//执行avgAge.as("columnName") 汇报org.apache.spark.sql.AnalysisException错误 别名只能在上面指定(目前测试是这样)
.show()
}
}
输出结果:
+---+---------------+---------+
|age| ip| name|
+---+---------------+---------+
| 24| 192.168.0.8| lillcol|
|100| 192.168.255.1| adson|
| 39| 192.143.255.1| wuli|
| 20| 192.168.255.1| gu|
| 15| 243.168.255.9| ason|
| 1| 108.168.255.1| tianba|
| 25|222.168.255.110|clearlove|
| 30|222.168.255.110|clearlove|
+---+---------------+---------+
+------+
|avgAge|
+------+
| 31.75|
+------+
本文为原创文章,转载请注明出处!!!