Custom functions are called user-defined functions (UDFs).
They come in three kinds:
UDF: takes one row and returns one result (one-to-one). For example, a function that takes an IP address and returns the corresponding province.
UDTF: takes one row and returns multiple rows (one-to-many, as in Hive). Spark SQL has no UDTF; in Spark the same effect is achieved with flatMap (see the sketch after this list).
UDAF: takes multiple rows and returns one row (many-to-one), i.e. an aggregate. count and sum are Spark's built-in aggregate functions, but complex business logic requires defining your own.
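Since Spark SQL has no UDTF, a one-to-many transformation is written with flatMap instead. Here is a minimal sketch; the object name and sample data are made up for the illustration:
import org.apache.spark.sql.{Dataset, SparkSession}
object UdtfViaFlatMap {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UdtfViaFlatMap")
      .master("local[4]")
      .getOrCreate()
    import spark.implicits._
    // One input row yields several output rows -- the one-to-many behavior of a UDTF
    val lines: Dataset[String] = Seq("a,b,c", "d,e").toDS()
    val words: Dataset[String] = lines.flatMap(_.split(","))
    words.show() // five rows: a, b, c, d, e
    spark.stop()
  }
}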
Let's walk through how to use a UDF and a UDAF.
Example: computing an IP address's home province.
We define a function that takes an IP address and returns the corresponding province, then register it; after that, the custom function can be used directly in SQL statements.
package XXX
import cn.edu360.sparkIpTest.TestIp
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
/**
 * Created by ...
 *
 * A join against the IP rules table is too expensive and far too slow; the
 * workaround is to cache the IP rules on every Executor as a broadcast variable.
 */
object IpLocationSQL2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("IpLocationSQL")
      .master("local[4]")
      .getOrCreate()
    // Read the IP rules file (ip.txt) from HDFS
    import spark.implicits._
    val rulesLines: Dataset[String] = spark.read.textFile("<path to IP rules>")
    // Parse the IP rules data.
    // This runs on the Executors; each Executor parses only part of the rules.
    val ipRulesDataset: Dataset[(Long, Long, String)] = rulesLines.map(line => {
      val fields = line.split("[|]")
      val startNum = fields(2).toLong
      val endNum = fields(3).toLong
      val province = fields(6)
      (startNum, endNum, province)
    })
    // Collect the parsed rules back to the Driver
    val ipRulesInDriver: Array[(Long, Long, String)] = ipRulesDataset.collect()
    // Broadcast the IP rules to the Executors
    val ipRulesBroadcastRef: Broadcast[Array[(Long, Long, String)]] = spark.sparkContext.broadcast(ipRulesInDriver)
    // Next, read the access log data
    val accessLines: Dataset[String] = spark.read.textFile("<path to access log>")
    // Parse the log lines: extract the IP, convert it to decimal form,
    // and compare it against the IP rules (via binary search)
    val ips: Dataset[Long] = accessLines.map(line => {
      val fields = line.split("[|]")
      val ip = fields(1)
      // convert the IP to its decimal representation
      val ipNum = TestIp.ip2Long(ip)
      ipNum
    })
    val ipDataFrame: DataFrame = ips.toDF("ipNum")
    // Create a temporary view
    ipDataFrame.createTempView("v_ipNum")
    // Define a user-defined function (UDF) and register it.
    // Given the decimal form of an IP address, it returns the province name.
    spark.udf.register("ip2Province", (ipNum: Long) => {
      // Look up the IP rules (broadcast beforehand, so already on the Executors).
      // The function body runs on the Executors; dereferencing the broadcast
      // variable yields the IP rules data.
      val ipRulesInExecutor: Array[(Long, Long, String)] = ipRulesBroadcastRef.value
      // Find the province name by binary-searching on the decimal IP
      val index: Int = TestIp.binarySearch(ipRulesInExecutor, ipNum)
      var province = "unknown"
      if (index != -1) {
        province = ipRulesInExecutor(index)._3
      }
      province
    })
    // Run the SQL
    val result: DataFrame = spark.sql("SELECT ip2Province(ipNum) province, COUNT(*) counts FROM v_ipNum GROUP BY province ORDER BY counts DESC")
    result.show()
    // Release resources
    spark.stop()
  }
}
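The example above calls two helpers, TestIp.ip2Long and TestIp.binarySearch, which are imported from cn.edu360.sparkIpTest but whose code is not shown. A minimal sketch of what they could look like follows; this is an assumption for illustration, and the original project's implementation may differ:
object TestIp {
  // Convert a dotted-quad IPv4 string such as "1.2.3.4" to its decimal (Long) form
  def ip2Long(ip: String): Long = {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for (fragment <- fragments) {
      ipNum = fragment.toLong | ipNum << 8L
    }
    ipNum
  }

  // Binary-search the (startNum, endNum, province) rules, sorted by startNum,
  // for the range containing ip; return its index, or -1 if no range matches
  def binarySearch(lines: Array[(Long, Long, String)], ip: Long): Int = {
    var low = 0
    var high = lines.length - 1
    while (low <= high) {
      val middle = (low + high) / 2
      if (ip >= lines(middle)._1 && ip <= lines(middle)._2) {
        return middle
      }
      if (ip < lines(middle)._1) {
        high = middle - 1
      } else {
        low = middle + 1
      }
    }
    -1
  }
}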
In Spark, a custom aggregate function must extend the abstract class UserDefinedAggregateFunction and override its methods.
First, let's look at the source code of this class:
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A `StructType` represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
   *
   * ```
   * new StructType()
   *  .add("doubleInput", DoubleType)
   *  .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def inputSchema: StructType

  /**
   * A `StructType` represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
   * the returned `StructType` will look like
   *
   * ```
   * new StructType()
   *  .add("doubleInput", DoubleType)
   *  .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def bufferSchema: StructType

  /**
   * The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].
   *
   * @since 1.5.0
   */
  def dataType: DataType

  /**
   * Returns true iff this function is deterministic, i.e. given the same input,
   * always return the same output.
   *
   * @since 1.5.0
   */
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   *
   * @since 1.5.0
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   *
   * @since 1.5.0
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   *
   * @since 1.5.0
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   *
   * @since 1.5.0
   */
  def evaluate(buffer: Row): Any
}
As you can see, extending this class requires overriding eight methods.
What each method means:
inputSchema: the data types of the input arguments
bufferSchema: the data types of the intermediate results (the aggregation buffer)
dataType: the data type of the final result
deterministic: whether the function is deterministic, i.e. the same input always yields the same output; normally true
initialize: sets the initial values (the zero value of the aggregation buffer)
update: updates the intermediate result each time an input row is processed (update performs the per-partition computation)
merge: global aggregation (combines the partial results of the individual partitions)
evaluate: computes the final result
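To make the division of labor concrete, consider an aggregation whose buffer holds a running product and a count, with the input values 1, 2, 3 split across two partitions as [1, 2] and [3]. initialize sets each partition's buffer to (1.0, 0); update folds the rows in, leaving (2.0, 2) and (3.0, 1); merge combines the two buffers into (6.0, 3); and evaluate turns that into 6^(1/3) ≈ 1.817. This is exactly the buffer layout used in the geometric-mean example below.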
Example: computing the geometric mean.
package XXX
import java.lang
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
 * Created by ...
 */
object UdafTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UdafTest")
      .master("local[4]") // run in local mode
      .getOrCreate()
    // Instantiate the custom aggregate function
    val geoMean = new GeoMean
    // Test data: create a Dataset of the numbers 1 through 10
    val range: Dataset[lang.Long] = spark.range(1, 11)
    //========== SQL style ==========
    // Register the range Dataset as a view
    // range.createTempView("v_range")
    // // Register our custom aggregate function
    // spark.udf.register("gm", geoMean)
    // // Write the SQL
    // val result: DataFrame = spark.sql("SELECT gm(id) result FROM v_range")
    //========== DSL style ==========
    import spark.implicits._
    val result: DataFrame = range.groupBy().agg(geoMean($"id").as("result"))
    // Show the result
    result.show()
    spark.stop()
  }
}
class GeoMean extends UserDefinedAggregateFunction {
  // Data types of the input
  override def inputSchema: StructType = StructType(List(
    StructField("value", DoubleType)
  ))
  // Data types of the intermediate results
  override def bufferSchema: StructType = StructType(List(
    // the running product of the inputs
    StructField("product", DoubleType),
    // the count of numbers that took part in the computation
    StructField("counts", LongType)
  ))
  // Data type of the final result
  override def dataType: DataType = DoubleType
  // Whether the function is deterministic; normally true
  override def deterministic: Boolean = true
  // Set the initial values
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // initial value of the running product
    buffer(0) = 1.0
    // initial count of numbers seen
    buffer(1) = 0L
  }
  // Update the intermediate result each time an input row is processed
  // (update performs the per-partition computation)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // multiply the previous intermediate product by the new input
    buffer(0) = buffer.getDouble(0) * input.getDouble(0)
    // update the count as well
    buffer(1) = buffer.getLong(1) + 1L
  }
  // Global aggregation (combine the partial results of the partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // multiply the partition products together
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
    // add the partition counts together
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    // take the product and the count out of the buffer, then raise the
    // product to the power 1/count to get the geometric mean
    math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
  }
}
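Run as written, the DSL query computes the geometric mean of 1 through 10, i.e. (10!)^(1/10) = 3628800^0.1 ≈ 4.5287, so result.show() should print a single row close to that value.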