A UDF takes one record in and produces one record out, a one-to-one mapping, much like the map operator.
A UDF can be used in two ways: in SQL, or in the DSL.
// In SQL (assumes df has been registered as the temp view "sensor")
val myconcat3 = sparkSession.udf.register("myconcat3", new MyConcat)
sparkSession.sql("select myconcat3(id,timestamp) as newid,temperature from sensor").show()
// In the DSL
val myconcat3 = sparkSession.udf.register("myconcat3", new MyConcat)
df.select(myconcat3($"id", $"timestamp") as "newid", $"temperature").show()
There are three ways to implement a UDF:
// Way 1: register an anonymous function directly
val myconcat = sparkSession.udf.register("myconcat", (data1: String, data2: Long) => {
  data1.concat(data2.toString)
})
// Way 2: import the udf function and build a UserDefinedFunction with it
import org.apache.spark.sql.functions.{col, udf}
val myconcat2 = udf[String, String, Long]((data1: String, data2: Long) => {
  data1.concat(data2.toString)
})
Note: the udf function can take more than one input parameter. The example above has two, so the prototype is udf[R, T1, T2], where R is the return type, T1 is the type of the first parameter, and T2 is the type of the second.
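The same pattern extends to more parameters. A minimal sketch of a hypothetical three-argument UDF (myconcat4 and its concatenation logic are illustrative only, not part of the example program):

// Hypothetical three-argument UDF: the prototype becomes udf[R, T1, T2, T3]
val myconcat4 = udf[String, String, Long, Double]((id: String, ts: Long, temp: Double) => {
  id.concat("_").concat(ts.toString).concat("_").concat(temp.toString)
})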
// Way 3: implement the Function2 trait in a class
class MyConcat extends Function2[String, Long, String] with Serializable {
  override def apply(v1: String, v2: Long): String = {
    v1.concat(v2.toString)
  }
}
Note: in Function2[T1, T2, R], T1 and T2 are the parameter types and R is the return type (the reverse of the udf prototype, which lists the return type first).
package com.hjt.yxh.hw.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object UDFApp {

  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf()
    conf.setMaster("local[*]").setAppName("UDFApp")

    val sparkSession: SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString", true)
      .getOrCreate()

    import sparkSession.implicits._

    val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    val df = sparkSession.read
      .format("csv")
      .schema("id VARCHAR(32),timestamp BIGINT,temperature DECIMAL(5,2)")
      .load(inpath)
    df.show(false)

    // Way 1: register an anonymous function
    val myconcat = sparkSession.udf.register("myconcat", (data1: String, data2: Long) => {
      data1.concat(data2.toString)
    })
    df.select(myconcat($"id", $"timestamp") as "newid", $"temperature").show()

    // Way 2: use the udf function
    val myconcat2 = udf[String, String, Long]((data1: String, data2: Long) => {
      data1.concat(data2.toString)
    })
    sparkSession.udf.register("myconcat2", myconcat2)
    df.select(myconcat2($"id", $"timestamp") as "newid", $"temperature").show()

    // Way 3: a class implementing Function2
    val myconcat3 = sparkSession.udf.register("myconcat3", new MyConcat)
    df.select(myconcat3($"id", $"timestamp") as "newid", $"temperature").show()

    sparkSession.close()
  }

  class MyConcat extends Function2[String, Long, String] with Serializable {
    override def apply(v1: String, v2: Long): String = {
      v1.concat(v2.toString)
    }
  }
}
Tip: using VARCHAR in the schema throws an error out of the box, because Spark does not enable CHAR/VARCHAR support by default; set spark.sql.legacy.charVarcharAsString to true so these types are treated as plain strings:
val sparkSession: SparkSession = SparkSession.builder()
  .config(conf)
  .config("spark.sql.legacy.charVarcharAsString", true)
  .getOrCreate()
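Alternatively (this variant is not in the original example), the column can simply be declared as STRING, in which case no legacy flag is needed. A sketch reusing the sparkSession and inpath from above; df2 is just an illustrative name:

// Assumption: same sparkSession and inpath as in UDFApp; STRING avoids the CHAR/VARCHAR restriction
val df2 = sparkSession.read
  .format("csv")
  .schema("id STRING, timestamp BIGINT, temperature DECIMAL(5,2)")
  .load(inpath)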
A UDAF is a user-defined aggregate function: it takes multiple input rows, aggregates them, and outputs a single aggregated value.
Like a UDF, a UDAF can be used in two ways: in SQL, or in the DSL.
// In SQL (requires the "sensor" temp view)
val myavg = udaf(new MyAggregator)
sparkSession.udf.register("myavg", myavg)
sparkSession.sql("select id,myavg(temperature) from sensor group by id").show()
// In the DSL
val myavg = udaf(new MyAggregator)
sparkSession.udf.register("myavg", myavg)
ds.groupBy("id").agg(myavg($"temperature")).show()
The custom-Aggregator approach requires extending the Aggregator[IN, BUF, OUT] class:
* IN: the type of the data fed into the aggregation
* BUF: the type of the aggregation buffer
* OUT: the type of the result output
Implementation of a weakly typed UDAF:
case class AggBuffer(var count: Long, var sum: Double)

class MyAggregator extends Aggregator[Double, AggBuffer, Double] {
  // initialize the buffer
  override def zero: AggBuffer = {
    AggBuffer(0, 0)
  }

  // reduce: fold one input value into the buffer
  override def reduce(b: AggBuffer, a: Double): AggBuffer = {
    b.count += 1
    b.sum += a
    b
  }

  // merge two buffers
  override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // compute and return the final result
  override def finish(reduction: AggBuffer): Double = {
    reduction.sum / reduction.count
  }

  // encoder for the buffer (Encoders.product is the standard choice for a custom case class)
  override def bufferEncoder: Encoder[AggBuffer] = Encoders.product

  // encoder for the output
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
package com.hjt.yxh.hw.sparksql

import com.hjt.yxh.hw.transmate.SensorReading
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{avg, udaf}

case class AggBuffer(var count: Long, var sum: Double)

class MyAggregator extends Aggregator[Double, AggBuffer, Double] {
  // initialize the buffer
  override def zero: AggBuffer = {
    AggBuffer(0, 0)
  }

  // reduce: fold one input value into the buffer
  override def reduce(b: AggBuffer, a: Double): AggBuffer = {
    b.count += 1
    b.sum += a
    b
  }

  // merge two buffers
  override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // compute and return the final result
  override def finish(reduction: AggBuffer): Double = {
    reduction.sum / reduction.count
  }

  // encoder for the buffer (Encoders.product is the standard choice for a custom case class)
  override def bufferEncoder: Encoder[AggBuffer] = Encoders.product

  // encoder for the output
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object UDAFApp {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf()
    conf.setMaster("local[*]")
      .setAppName("UDAFApp")

    val sparkSession: SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString", true)
      .getOrCreate()

    import sparkSession.implicits._

    val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    val df = sparkSession.read.format("CSV")
      .schema("id VARCHAR(32), timestamp BIGINT,temperature Double")
      .load(inpath)
    val ds: Dataset[SensorReading] = df.as[SensorReading]
    ds.createOrReplaceTempView("sensor")
    ds.show()

    val myavg = udaf(new MyAggregator)
    sparkSession.udf.register("myavg", myavg)
    ds.groupBy("id").agg(myavg($"temperature")).show()

    sparkSession.stop()
  }
}
The strongly typed implementation uses the case class itself (SensorReading) as the aggregator's input type and is applied as a TypedColumn on a Dataset:

package com.hjt.yxh.hw.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator

case class AggBuffer(var count: Long, var sum: Double)

case class SensorReading(id: String, timestamp: BigInt, temperature: Double)

class MyAggregator extends Aggregator[SensorReading, AggBuffer, Double] {
  // initialize the buffer
  override def zero: AggBuffer = {
    AggBuffer(0, 0)
  }

  // reduce: fold one input record into the buffer
  override def reduce(b: AggBuffer, a: SensorReading): AggBuffer = {
    b.count += 1
    b.sum += a.temperature
    b
  }

  // merge two buffers
  override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // compute and return the final result
  override def finish(reduction: AggBuffer): Double = {
    reduction.sum / reduction.count
  }

  // encoder for the buffer (Encoders.product is the standard choice for a custom case class)
  override def bufferEncoder: Encoder[AggBuffer] = Encoders.product

  // encoder for the output
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object UDAFApp {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf()
    conf.setMaster("local[*]")
      .setAppName("UDAFApp")

    val sparkSession: SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString", true)
      .getOrCreate()

    import sparkSession.implicits._

    val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    val df = sparkSession.read.format("CSV")
      .schema("id VARCHAR(32), timestamp BIGINT,temperature Double")
      .load(inpath)
    val ds: Dataset[SensorReading] = df.as[SensorReading]
    ds.createOrReplaceTempView("sensor")

    // apply the aggregator as a TypedColumn over the whole Dataset
    val myavg = new MyAggregator().toColumn.name("avg_temperature")
    sparkSession.udf.register("myavg", functions.udaf(new MyAggregator()))
    val ret = ds.select(myavg)
    ret.show()

    sparkSession.stop()
  }
}
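If a per-sensor average is wanted with the typed API, groupByKey can be combined with the same TypedColumn. A minimal sketch, assuming the ds built in the example above (perId is just an illustrative name):

// Group by sensor id, then apply the typed aggregation column per group
val perId = ds.groupByKey(_.id)
  .agg(new MyAggregator().toColumn.name("avg_temperature"))
perId.show()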
There is also the older UserDefinedAggregateFunction API (deprecated since Spark 3.0 in favor of Aggregator plus functions.udaf), which is implemented like this:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, DoubleType, LongType, StructType}

class MyAvg extends UserDefinedAggregateFunction {
  // schema of the input data
  override def inputSchema: StructType = new StructType()
    .add("temperature", "double")

  // schema of the aggregation buffer
  override def bufferSchema: StructType = new StructType()
    .add("count", LongType)
    .add("sum", DoubleType)

  // data type of the output
  override def dataType: DataType = DataTypes.DoubleType

  override def deterministic: Boolean = true

  // initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0.0)
  }

  // update the buffer with one input row (comparable to reduce)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    var count = buffer.getLong(0)
    var sum = buffer.getDouble(1)
    count += 1
    sum += input.getDouble(0)
    buffer.update(0, count)
    buffer.update(1, sum)
  }

  // merge two buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1))
  }

  // compute and return the result
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getLong(0)
  }
}
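A minimal usage sketch for this legacy UDAF (not in the original listing), assuming the same "sensor" temp view as in the examples above; myavg_legacy is an illustrative name, and this register overload is itself deprecated along with the API:

// Register the legacy UDAF and call it from SQL (assumes the "sensor" view exists)
sparkSession.udf.register("myavg_legacy", new MyAvg)
sparkSession.sql("select id, myavg_legacy(temperature) as avg_temperature from sensor group by id").show()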
Spark SQL itself does not provide UDTF support; you have to enable Hive support in order to use one.
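Enabling Hive support means the spark-hive module (which transitively brings in the Hive classes such as GenericUDTF and the ObjectInspector family used below) must be on the classpath. A sketch of the sbt dependency, with the version being an assumption that should match your Spark distribution:

// build.sbt (assumption: Spark 3.x, Scala 2.12; align the version with your spark-core/spark-sql)
libraryDependencies += "org.apache.spark" %% "spark-hive" % "3.3.0"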
package com.hjt.yxh.hw.sql

import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

class MySplit extends GenericUDTF {

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    // exactly one argument must be passed in
    if (argOIs.length != 1) {
      throw new UDFArgumentException("exactly one argument is required")
    }
    // check the argument type
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("argument type mismatch")
    }
    // declare the output: a single string column named "type"
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  override def process(objects: Array[AnyRef]): Unit = {
    val data_list: Array[String] = objects(0).toString.split("_")
    // emit one output row per split value
    for (item <- data_list) {
      val temp = new Array[String](1)
      temp(0) = item
      forward(temp)
    }
  }

  override def close(): Unit = {}
}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object UDTFApp {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf()
    conf.setMaster("local[*]").setAppName("UDTFTest")

    val sparkSession: SparkSession = SparkSession.builder()
      .enableHiveSupport()
      .config(conf)
      .getOrCreate()
    import sparkSession.implicits._

    val inpath: String = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    val df1 = sparkSession.read.format("csv")
      .schema("id String,timestamp Bigint,temperature double")
      .load(inpath)

    // register a temp view
    df1.createOrReplaceTempView("sensor")

    // create the function by referencing the class name
    sparkSession.sql("create TEMPORARY function mySplit as 'com.hjt.yxh.hw.sql.MySplit'")
    sparkSession.sql("select mySplit(id) from sensor").show()

    sparkSession.stop()
  }
}