User-defined functions fall into three categories:
1). UDF: takes one row and returns one result (one-to-one). The custom function built in the previous post, 使用SparkSQL实现根据ip地址计算归属地二, is a UDF: it takes an IP address encoded as a decimal number and returns a province.
2). UDTF: takes one row and returns multiple rows (one-to-many). Spark SQL has no UDTF interface, because the same effect is achieved with flatMap (see the sketch after this list).
3). UDAF: takes multiple rows and returns one row. The A stands for aggregate; when the built-in aggregates cannot express the business logic, you implement the aggregation yourself.
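A minimal sketch of the first two categories, assuming a local SparkSession; the to_province function and the sample values are placeholders for illustration, not the real IP-to-province lookup from the earlier post:

import org.apache.spark.sql.SparkSession

object UdfVsFlatMapSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("UdfVsFlatMapSketch").master("local[*]").getOrCreate()
    import spark.implicits._

    // UDF: one value in, one value out; to_province is a stand-in for the real lookup
    spark.udf.register("to_province", (ip: Long) => if (ip < 2000000000L) "ProvinceA" else "ProvinceB")
    Seq(16777472L, 3394789376L).toDF("ip")
      .selectExpr("ip", "to_province(ip) AS province")
      .show()

    // UDTF-style one-to-many: flatMap on a Dataset produces several rows per input row
    Seq("hello spark", "hello sql").toDS()
      .flatMap(_.split(" "))
      .show()

    spark.stop()
  }
}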
The rest of this post shows how to write a custom UDAF.
The running example computes a geometric mean, i.e. the n-th root of the product of n numbers (for example, the geometric mean of 2 and 8 is the square root of 2 × 8, which is 4). When n is very large the computation no longer fits on a single machine, so the job is run on a Spark cluster instead.
As the figure shows, each partition computes its own count n and partial product t; the partial results are then merged, and Math.pow takes the n-th root to produce the geometric mean.
The implementation:
package cn.ysjh0014.SparkSql
import java.lang
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}
object UDAFTest {

  def main(args: Array[String]): Unit = {
    val session: SparkSession = SparkSession.builder().appName("UDAFTest").master("local[*]").getOrCreate()
    val udaf = new UDAFys

    // register the function by name (used together with the SQL variant below)
    // session.udf.register("udaf", udaf)
    val range: Dataset[lang.Long] = session.range(1, 11)
    // range.createTempView("table")
    // val df = session.sql("SELECT udaf(id) result FROM table")

    import session.implicits._
    val df = range.agg(udaf($"id").as("geomean"))
    df.show()

    session.stop()
  }
}
class UDAFys extends UserDefinedAggregateFunction {

  // type of the input data
  override def inputSchema: StructType = StructType(List(
    StructField("value", DoubleType)
  ))

  // types of the intermediate (buffer) results
  override def bufferSchema: StructType = StructType(List(
    // running product of the values seen so far
    StructField("project", DoubleType),
    // count of the values seen so far
    StructField("Num", LongType)
  ))

  // type of the final result
  override def dataType: DataType = DoubleType

  // whether the function is deterministic (same input, same output); usually true
  override def deterministic: Boolean = true

  // initial buffer values; positions and types must match bufferSchema above
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // initial value of the running product
    buffer(0) = 1.0
    // initial count of values seen
    buffer(1) = 0L
  }

  // called once per input row; updates the buffer within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // multiply the incoming value into the running product
    buffer(0) = buffer.getDouble(0) * input.getDouble(0)
    // bump the count
    buffer(1) = buffer.getLong(1) + 1L
  }

  // global aggregation: merge the buffers produced by different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // multiply the partial products together and add the counts
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // compute the final result: the n-th root of the product
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
  }
}
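The commented-out lines in UDAFTest show the alternative way to use the same class: register it under a name and call it from SQL like a built-in aggregate. A minimal sketch, assuming the same session, range and UDAFys as above:

session.udf.register("udaf", new UDAFys)
range.createTempView("table")
session.sql("SELECT udaf(id) AS geomean FROM table").show()

The second example below goes a step further. It builds a table of (user, platform, fe) rows, joins it with a renamed copy of itself on platform, and then uses a UDAF (SigmaUdafTest, listed after the driver) to collect, for each user, the other users on the same platform whose fe value lies within ±100 of that user's own fe.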
package test.udaf
import org.apache.log4j.Logger
import org.apache.spark.Partition
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, SparkSession}
object UDAFTest2 {
  val logger = Logger.getLogger(this.getClass)

  val spark: SparkSession = SparkSession
    .builder()
    .appName("local-test")
    .master("local[4]")
    //.enableHiveSupport()
    //.config("spark.shuffle.service.enabled", true)
    //.config("spark.driver.maxResultSize", "4G")
    //.config("spark.sql.parquet.writeLegacyFormat", true)
    .getOrCreate()
  spark.sparkContext.setLogLevel("warn")

  import spark.implicits._

  def main(args: Array[String]): Unit = {
    val dataSource = Seq(
      ("11111110", "F2", "100"),
      ("11111111", "F2", "200"),
      ("11111112", "F2", "300"),
      ("11111113", "F2", "400"),
      ("11111114", "F2", "500"),
      ("11111115", "F2", "600"),
      ("11111116", "F2", "700"),
      ("11111117", "F2", "800"),
      ("11111118", "F2", "900"),
      ("11111119", "F2", "1000"),
      ("22222220", "F3", "100"),
      ("22222221", "F3", "200"),
      ("22222222", "F3", "300"),
      ("22222223", "F3", "400"),
      ("22222224", "F3", "500"),
      ("22222225", "F3", "600"),
      ("22222226", "F3", "700"),
      ("22222227", "F3", "800"),
      ("22222228", "F3", "900"),
      ("22222229", "F3", "1000")
    )
    val rawDF: DataFrame = spark.createDataFrame(dataSource).toDF("user", "platform", "fe")
      .withColumn("fe", $"fe".cast(DoubleType))

    val copyDF = rawDF
      .select(
        (rawDF.columns).map(i => col(i).alias(s"${i}2")): _*
      )
    val joinDF = rawDF.join(copyDF, rawDF("platform") === copyDF("platform2"), "inner")
      .filter($"user" =!= $"user2")

    val partitions: Int = joinDF.rdd.getNumPartitions
    val partitions1: Array[Partition] = joinDF.rdd.partitions
    logger.warn("========>>>>>>>> " + partitions)
    partitions1.foreach(e => {
      println(e.index)
    })
    joinDF.printSchema()
    joinDF.show(1000, false)

    val cal_fe_sigma_udf: UserDefinedAggregateFunction = spark.udf.register("cal_fe_sigma", new SigmaUdafTest())
    val tmpDF = joinDF.groupBy("user", "platform")
      .agg(
        cal_fe_sigma_udf($"user", $"fe", $"user2", $"fe2")
      )
    tmpDF.printSchema()
    tmpDF.show(false)
  }
}
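The SigmaUdafTest class used by the driver above buffers every (user, fe, user2, fe2) row of a group into four arrays and does the actual filtering only at the end, in evaluate: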
package test.udaf
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer
class SigmaUdafTest extends UserDefinedAggregateFunction {

  // input schema of the aggregate function
  // override def inputSchema: StructType = {
  //   StructType(StructField("esn", StringType)
  //     :: StructField("fe", DoubleType)
  //     :: StructField("fe2", DoubleType)
  //     :: Nil)
  // }
  override def inputSchema: StructType = {
    new StructType()
      .add("user", StringType)
      .add("fe", DoubleType)
      .add("user2", StringType)
      .add("fe2", DoubleType)
  }

  // buffer schema: the types of the intermediate results kept during aggregation
  override def bufferSchema: StructType = {
    new StructType()
      .add("userArray", DataTypes.createArrayType(StringType))
      .add("feArray", DataTypes.createArrayType(DoubleType))
      .add("user2Array", DataTypes.createArrayType(StringType))
      .add("fe2Array", DataTypes.createArrayType(DoubleType))
  }

  // return type of the aggregate function
  override def dataType: DataType = {
    StringType
  }

  // whether the aggregate is deterministic, i.e. the same input always produces the same output
  override def deterministic: Boolean = true

  // initialize the buffer: each field starts as an empty sequence that incoming rows are appended to
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Seq[String]())
    buffer.update(1, Seq[Double]())
    buffer.update(2, Seq[String]())
    buffer.update(3, Seq[Double]())
  }

  // called once per input row: append its values to the buffered sequences
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getSeq(0) :+ input.getString(0))
    buffer.update(1, buffer.getSeq(1) :+ input.getDouble(1))
    buffer.update(2, buffer.getSeq(2) :+ input.getString(2))
    buffer.update(3, buffer.getSeq(3) :+ input.getDouble(3))
  }

  // merge two buffers and store the result back into buffer1 (local aggregation per partition, then global)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
    buffer1.update(1, buffer1.getSeq[Double](1) ++ buffer2.getSeq[Double](1))
    buffer1.update(2, buffer1.getSeq[String](2) ++ buffer2.getSeq[String](2))
    buffer1.update(3, buffer1.getSeq[Double](3) ++ buffer2.getSeq[Double](3))
  }

  // compute the final result for one group
  override def evaluate(buffer: Row): Any = {
    val userArrayAll = buffer.getSeq[String](0)
    val feArrayAll = buffer.getSeq[Double](1)
    val user2ArraAll = buffer.getSeq[String](2)
    val fe2ArrayAll = buffer.getSeq[Double](3)
    // the group key is (user, platform), so every buffered user/fe value is the same; take the first
    val user = userArrayAll(0)
    val fe = feArrayAll(0)
    val minFe: Double = fe - 100d
    val maxFe: Double = fe + 100d
    var lst = Seq[String]()
    for (i <- user2ArraAll.indices) {
      if (minFe <= fe2ArrayAll(i) && fe2ArrayAll(i) <= maxFe) {
        // keep the neighbours whose fe2 lies within +/-100 of fe, formatted as user2_fe2
        lst = lst :+ (user2ArraAll(i) + "_" + fe2ArrayAll(i))
      }
    }
    // result format: user:user2_fe2,user2_fe2...
    user + ":" + lst.mkString(",")
  }

  // helper that returns the original indices of the elements after sorting them lexicographically; not used in this example
  def getSortedIndexArray(occurrenceDateTimeArray: Array[String]): ListBuffer[Int] = {
    val lst = ListBuffer[String]()
    for (i <- occurrenceDateTimeArray.indices) {
      lst.append(occurrenceDateTimeArray(i) + "_" + i)
    }
    val lst1 = lst.sorted
    val sortedIndexArray = ListBuffer[Int]()
    for (i <- lst1.indices) {
      sortedIndexArray.append(lst1(i).split("_")(1).toInt)
    }
    sortedIndexArray
  }
}
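For the sample data above, each output row is the group key plus the formatted string. For user 11111111 on platform F2 (fe = 200.0), the window is [100.0, 300.0], so the UDAF returns something like 11111111:11111110_100.0,11111112_300.0; the order of the matches can vary with partitioning, since the buffered sequences are built in whatever order the rows arrive.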