Spark Series -- SparkSQL (5): User-Defined Functions

1. User-defined UDF functions

Users can register their own functions through the spark.udf facility.

scala> val df = spark.read.json("/input/people.json")

scala> spark.udf.register("addName",(x:String)=>"Name:"+x)
res18: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("Select addName(name), age from people").show()
+-----------------+---+
|UDF:addName(name)|age|
+-----------------+---+
|        Name:Mina| 19|
|        Name:Andy| 30|
|     Name:Michael| 29|
+-----------------+---+
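A function registered this way is only visible to SQL queries. If you prefer the DataFrame API, the same logic can be wrapped with org.apache.spark.sql.functions.udf; a minimal sketch against the same df (the value addNameUdf and the alias name_with_prefix are just illustrative names, not from the original example):

import org.apache.spark.sql.functions.udf

// Wrap the same lambda as a DataFrame-API UDF
val addNameUdf = udf((x: String) => "Name:" + x)
// Apply it to the name column; the output matches the SQL query above
df.select(addNameUdf(df("name")).as("name_with_prefix"), df("age")).show()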

2. User-defined aggregate functions

(1) Weakly typed user-defined aggregate function

A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. The example below is a custom aggregate function that computes the average salary.

The employees.json file looks like this:

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

object MyAverage extends UserDefinedAggregateFunction {

  // Data type of the aggregate function's input arguments
  def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

  // Data types of the values in the aggregation buffer
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // Data type of the return value
  def dataType: DataType = DoubleType

  // Whether this function always returns the same output for the same input
  def deterministic: Boolean = true

  // Initialize the buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // running total of salaries
    buffer(0) = 0L
    // number of salary values seen
    buffer(1) = 0L
  }

  // Update the buffer with new input data (within the same executor)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge aggregation buffers coming from different executors
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

// Register the function
spark.udf.register("myAverage", MyAverage)

// Read the data
val df = spark.read.json("hdfs://hadoop0:9000/input/employees.json")
df.createOrReplaceTempView("employees")
df.show()

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
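Against the four sample records the result should be (3000 + 4500 + 3500 + 4000) / 4 = 3750.0. As a quick sanity check, the custom UDAF can be compared with Spark's built-in avg over the same temp view, which should print the same value:

spark.sql("SELECT avg(salary) AS builtin_average FROM employees").show()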

(2) Strongly typed user-defined aggregate function

A strongly typed custom aggregate function is implemented by extending Aggregator; again the example computes the average salary.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession

// Since the aggregator is strongly typed, the input and the buffer are modeled as case classes
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {

  // The zero value: holds the salary total and the record count, both starting at 0
  def zero: Average = Average(0L, 0L)

  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }

  // Merge intermediate results from different executors
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Compute the final output
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count

  // Encoder for the intermediate value type, which is a case class;
  // Encoders.product is the encoder for Scala tuples and case classes
  def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the final output value
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

import spark.implicits._

val ds = spark.read.json("hdfs://hadoop0:9000/input/employees.json").as[Employee]
ds.show()

// Turn the aggregator into a Column and give it a name
val averageSalary = MyAverage.toColumn.name("average_salary")

val result = ds.select(averageSalary)
result.show()
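MyAverage.toColumn yields a TypedColumn[Employee, Double], which is why it can be passed directly to ds.select. For the four sample employees it evaluates to the same 3750.0 that the weakly typed version returned.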
