Spark SQL custom UDF, UDAF, and UDTF functions: detailed examples
1. UDF
A UDF (user-defined function) maps one input value per row to one output value. Register it on the session with spark.udf.register, then call it from SQL like any built-in function:
spark.udf.register("prefix1", (name: String) => {
"Name:" + name
})
spark.sql("select *,prefix1(name) from users").show()
2. UDAF
A UDAF (user-defined aggregate function) folds many input rows into a single result per group.
(1) Spark 3.0.0 and later
Since 3.0.0 the recommended approach is to extend the typed Aggregator[IN, BUF, OUT] and adapt it for SQL with functions.udaf:
package com.yyds.Test

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// IN = Double (one age), BUF = (sum, count), OUT = Double (the average)
class _01_MyAvgNew extends Aggregator[Double, (Double, Int), Double] {
  // Initial value of the aggregation buffer
  override def zero: (Double, Int) = (0.0, 0)

  // Fold one input value into the buffer
  override def reduce(b: (Double, Int), a: Double): (Double, Int) =
    (b._1 + a, b._2 + 1)

  // Combine two partial buffers (e.g. from different partitions)
  override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) =
    (b1._1 + b2._1, b1._2 + b2._2)

  // Produce the final result from the buffer
  override def finish(reduction: (Double, Int)): Double =
    reduction._1 / reduction._2

  // Encoders for the intermediate buffer and the output
  override def bufferEncoder: Encoder[(Double, Int)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
package com.yyds.Test

import com.yyds.common.SparkUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{DataFrame, SparkSession}

object _01_SparkSql_UDAF {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val spark: SparkSession = SparkUtils.sparkSessionWithNoHive(this.getClass.getSimpleName)

    // Wrap the Aggregator with functions.udaf so it can be called from SQL
    import org.apache.spark.sql.functions._
    spark.udf.register("myAvg", udaf(new _01_MyAvgNew))

    val df: DataFrame = spark.read.json("file:\\D:\\works\\spark-sql\\data\\udaf.json")
    df.createTempView("students")

    spark.sql(
      """
        |select
        |  sex,
        |  myAvg(age) as avg
        |from
        |  students
        |group by sex
        """.stripMargin).show(2)
  }
}
The result:
+-----+------------------+
| sex| avg|
+-----+------------------+
| man|20.666666666666668|
|woman|18.666666666666668|
+-----+------------------+
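The same Aggregator also works through the typed Dataset API without any SQL registration, via its toColumn method. A minimal sketch (the inline sample values are my own assumption):

// Typed usage of the same Aggregator; requires an active SparkSession named spark
import spark.implicits._
import org.apache.spark.sql.Dataset

val ages: Dataset[Double] = Seq(20.0, 21.0, 22.0).toDS() // hypothetical sample values
// toColumn turns the Aggregator into a TypedColumn usable in a typed select
val avg: Double = ages.select(new _01_MyAvgNew().toColumn).head()
println(avg) // 21.0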
(2) Before Spark 3.0.0
Before 3.0.0 the untyped API was used: extend UserDefinedAggregateFunction (deprecated since 3.0.0) and manipulate the aggregation buffer through Row indices:
package com.yyds.Test

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class _01_MyAvgOld extends UserDefinedAggregateFunction {
  // Schema of the input column(s)
  override def inputSchema: StructType =
    new StructType().add("age", LongType)

  // Schema of the aggregation buffer: running sum and count
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  // Type of the final result
  override def dataType: DataType = DoubleType

  // The same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer (skip null ages to avoid an NPE on getLong)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Combine two partial buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final average
  override def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}
package com.yyds.Test

import com.yyds.common.SparkUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{DataFrame, SparkSession}

// Named _01_SparkSql_UDAF_Old so it does not clash with the Spark 3.0 driver above
object _01_SparkSql_UDAF_Old {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val spark: SparkSession = SparkUtils.sparkSessionWithNoHive(this.getClass.getSimpleName)

    // A UserDefinedAggregateFunction instance can be registered directly
    spark.udf.register("myAvg", new _01_MyAvgOld)

    val df: DataFrame = spark.read.json("file:\\D:\\works\\spark-sql\\data\\udaf.json")
    df.createTempView("students")

    spark.sql(
      """
        |select
        |  sex,
        |  myAvg(age) as avg
        |from
        |  students
        |group by sex
        """.stripMargin).show(2)
  }
}
The result:
+-----+------------------+
| sex| avg|
+-----+------------------+
| man|20.666666666666668|
|woman|18.666666666666668|
+-----+------------------+
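Besides SQL registration, a UserDefinedAggregateFunction instance can be applied directly in the DataFrame API, because it defines apply(Column*). A sketch reusing the df from the driver above:

import org.apache.spark.sql.functions.col

// Equivalent to the SQL query above, expressed with the DataFrame API
df.groupBy("sex")
  .agg(new _01_MyAvgOld()(col("age")).as("avg"))
  .show()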
3. UDTF
A UDTF (user-defined table-generating function) turns one input row into zero or more output rows. Spark SQL has no UDTF interface of its own, so this example implements Hive's GenericUDTF and registers it through a Hive-enabled session:
package com.yyds.Test

import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}

class _02_MyUdtf extends GenericUDTF {
  // Validate the arguments and declare the output schema: one string column named "type"
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("exactly one argument is required")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("the argument must be a primitive type")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  // Split the input on whitespace and emit one output row per token
  override def process(objects: Array[AnyRef]): Unit = {
    val strings: Array[String] = objects(0).toString.split("\\s+")
    for (str <- strings) {
      forward(Array(str))
    }
  }

  override def close(): Unit = {}
}
package com.yyds.Test

import com.yyds.common.SparkUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object _02_SparkSql_UDTF {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    // Hive support is required for CREATE TEMPORARY FUNCTION
    val spark: SparkSession = SparkUtils.sparkSessionWithHive(this.getClass.getSimpleName)
    import spark.implicits._

    val sc: SparkContext = spark.sparkContext
    // Each line is "id//name//space-separated course list"; keep only the rows for "ls"
    val lines: RDD[String] = sc.textFile("file:\\D:\\works\\spark-sql\\data\\udtf.txt")
    val stuDf: DataFrame = lines.map(_.split("//"))
      .filter(x => x(1).equals("ls"))
      .map(x => (x(0), x(1), x(2)))
      .toDF("id", "name", "class")
    stuDf.printSchema()
    stuDf.createOrReplaceTempView("student")

    // Register the Hive UDTF by its fully qualified class name, then call it in SQL
    spark.sql("CREATE TEMPORARY FUNCTION MyUDTF AS 'com.yyds.Test._02_MyUdtf'")
    val resultDF: DataFrame = spark.sql("select MyUDTF(class) from student")
    resultDF.show(50)
  }
}
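For reference, a hypothetical udtf.txt line consistent with the split("//") parsing, the filter on name == "ls", and the output shown would be:

1//ls//Hadoop scala kafka hive hbase Oozie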
The result:
+------+
| type|
+------+
|Hadoop|
| scala|
| kafka|
| hive|
| hbase|
| Oozie|
+------+
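As an aside, for this particular case a Hive UDTF is not strictly required: Spark SQL's built-in split and explode functions produce the same one-row-per-token result. A minimal sketch against the same student view:

// Built-in alternative: split the string, then explode one row per element
// ('\\\\s+' in the Scala source reaches Spark SQL as the regex \s+)
spark.sql("select explode(split(class, '\\\\s+')) as type from student").show(50)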