Spark SQL 自定义函数根据输入与输出之间的关系分为三类:
UDF —— 输入一行,输出一行
UDAF —— 输入多行,输出一行
UDTF —— 输入一行,输出多行
1、数据(likes.txt,每行为"姓名<Tab>爱好列表",多个爱好之间用英文逗号分隔)
大狗 三国,水浒,红楼
二狗 金瓶梅
二条 西游,唐诗宋词
2、需求:求出每个人的爱好个数
3、实现
/**
 * UDF demo (one row in, one row out): counts how many hobbies each person has.
 *
 * Reads a tab-separated file whose lines are `name<TAB>hobby1,hobby2,...`,
 * registers the `teams_length` UDF and invokes it both from SQL and from the
 * DataFrame API.
 *
 * @param args optional; `args(0)` overrides the input path (defaults to `D:\ssc\likes.txt`)
 */
def main(args: Array[String]): Unit = {
  val inputPath = args.headOption.getOrElse("D:\\ssc\\likes.txt")
  val spark = SparkSession.builder
    .master("local")
    .appName(this.getClass.getSimpleName)
    .getOrCreate()
  try {
    import spark.implicits._
    val df = spark.sparkContext.textFile(inputPath)
      // limit = 2 keeps an empty hobby column instead of dropping it. Lines
      // without a tab cannot form a Likes record, so they are skipped rather
      // than failing the job with ArrayIndexOutOfBoundsException on x(1).
      .map(_.split("\t", 2))
      .filter(_.length == 2)
      .map(x => Likes(x(0), x(1)))
      .toDF()
    df.createOrReplaceTempView("t_team")
    // Only non-blank entries are counted: "".split(",") yields Array(""), which
    // would otherwise report one hobby for a person with none. Null input -> 0.
    val teamsLengthUDF = spark.udf.register("teams_length", (input: String) =>
      Option(input).fold(0)(s => s.split(",").count(_.trim.nonEmpty))
    )
    println("--------------SQL方式----------------")
    spark.sql("select name,teams,teams_length(teams) as teams_length from t_team").show(false)
    println("--------------API方式----------------")
    df.select($"name", $"teams", teamsLengthUDF($"teams").as("teams_length")).show(false)
  } finally {
    // Release the local Spark context even if reading or a query fails.
    spark.stop()
  }
}
/**
 * One input record of the UDF demo.
 *
 * @param name  the person's name
 * @param teams the person's hobbies as a single comma-separated string
 */
case class Likes(
  name: String,
  teams: String
)
}
自定义 UDAF 需要继承 UserDefinedAggregateFunction(该 API 自 Spark 3.0 起已被标记为废弃,新代码推荐实现 Aggregator 并通过 functions.udaf 注册)。
Spark SQL 自带的聚合函数(以 sum 为例,多行输入、一行输出):
/**
 * Many-rows-in / one-row-out aggregation using Spark SQL's built-in `sum`:
 * prints the total age per sex over a small in-memory `user` table.
 *
 * @param spark active session used to build the DataFrame and run the query
 */
def udafWithSum(spark: SparkSession): Unit = {
  val schema = StructType(Seq(
    StructField("name", StringType, false),
    StructField("age", IntegerType, false),
    StructField("sex", StringType, false)
  ))
  val rows = new util.ArrayList[Row]()
  Seq(
    Row("Luck", 30, "M"),
    Row("Jack", 60, "M"),
    Row("Jim", 19, "F"),
    Row("Lily", 20, "F")
  ).foreach(r => rows.add(r))
  spark.createDataFrame(rows, schema).createOrReplaceTempView("user")
  spark.sql("select sex,sum(age) from user group by sex").show(false)
}
自定义UDAF
需求:男性和女性各自的平均年龄
/**
 * Custom UDAF computing the average of a numeric column:
 * avg = sum / number of values that took part in the computation.
 *
 * Like the built-in `avg`, null inputs are ignored. A group with no non-null
 * input evaluates to null instead of NaN (0.0 / 0).
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0;
 * consider migrating to `Aggregator` registered via `functions.udaf`.
 */
object JimAvgUDAF extends UserDefinedAggregateFunction {

  /**
   * Input type: a single nullable double column.
   * @return the input schema
   */
  override def inputSchema: StructType = {
    StructType(
      StructField("nums", DoubleType, true) :: Nil
    )
  }

  /**
   * Type of the intermediate aggregation buffer.
   * @return the buffer schema: running sum at index 0, value count at index 1
   */
  override def bufferSchema: StructType = {
    StructType(
      StructField("buffer1", DoubleType, true) :: // running sum of the inputs
      StructField("buffer2", LongType, true) :: Nil // number of values counted
    )
  }

  /**
   * Output data type.
   * @return DoubleType
   */
  override def dataType: DataType = DoubleType

  /**
   * Whether the same input always produces the same output.
   * @return true
   */
  override def deterministic: Boolean = true

  /**
   * Initializes the buffer to sum = 0.0, count = 0.
   * @param buffer the aggregation buffer to reset
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0L)
  }

  /**
   * Folds one input row into the buffer (aggregation within a partition).
   * @param buffer the partial buffer to update
   * @param input  a row matching [[inputSchema]]
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Row.getDouble throws on a null value, and a null must not count toward the average.
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
      buffer.update(1, buffer.getLong(1) + 1)
    }
  }

  /**
   * Merges two partial buffers (global aggregation), storing the result in buffer1.
   * @param buffer1 the buffer that receives the merged result
   * @param buffer2 another partial buffer
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  /**
   * Computes the final result from the merged buffer.
   * @param buffer the fully merged buffer
   * @return sum / count, or null when no non-null value was aggregated
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getDouble(0) / count
  }
}
调用
/**
 * Calls the custom [[JimAvgUDAF]] from SQL to print the average age per sex:
 * {{{
 * select sex,avg_udaf(age) as avg_age from user group by sex
 * }}}
 * NOTE(review): `age` is an IntegerType column while the UDAF declares a
 * DoubleType input. This relies on Spark casting the argument implicitly; confirm
 * on the target Spark version. Registering a `UserDefinedAggregateFunction` via
 * `spark.udf.register` is deprecated since Spark 3.0.
 *
 * @param spark active session; the UDAF is registered on it under the name `avg_udaf`
 */
def udafAvgWithSex(spark: SparkSession): Unit = {
  val rows = new util.ArrayList[Row]()
  rows.add(Row("Luck", 30, "M"))
  rows.add(Row("Jack", 60, "M"))
  rows.add(Row("Jim", 19, "F"))
  rows.add(Row("Lily", 20, "F"))
  val schema = StructType(
    List(
      StructField("name", StringType, false),
      StructField("age", IntegerType, false),
      StructField("sex", StringType, false)
    )
  )
  val df = spark.createDataFrame(rows, schema)
  df.createOrReplaceTempView("user")
  // The SQL below must use exactly the name registered here.
  spark.udf.register("avg_udaf", JimAvgUDAF)
  spark.sql("select sex,avg_udaf(age) as avg_age from user group by sex").show(false)
}
一进多出(UDTF 效果):用 Dataset.flatMap 把一行拆成多行
/**
 * One-row-in / many-rows-out (UDTF-style) demo. Splits each teacher's
 * comma-separated `courses` string into one [[Course]] row per course, then
 * prints the resulting schema and rows.
 *
 * Each course name is trimmed. Blank entries (e.g. from "a,,b" or a trailing
 * comma) are skipped, so they do not produce empty courses.
 *
 * @param spark active session used to build the DataFrame
 */
def udtfExplode(spark: SparkSession): Unit = {
  val rows = new util.ArrayList[Row]()
  rows.add(Row("Luck", "Java,JavaScript,Scala"))
  rows.add(Row("Jack", "History,English,Math"))
  val schema = StructType(
    List(
      StructField("teacher", StringType, false),
      StructField("courses", StringType, false)
    )
  )
  val df = spark.createDataFrame(rows, schema)
  // Provides the product Encoder that flatMap needs for the Course case class.
  import spark.implicits._
  val courseDs = df.flatMap { row =>
    val teacher = row.getString(0)
    row.getString(1)
      .split(",")
      .toList
      .map(_.trim)
      .filter(_.nonEmpty)
      .map(course => Course(teacher, course))
  }
  courseDs.printSchema()
  courseDs.show(false)
}
/**
 * One output row of the UDTF-style demo: a single (teacher, course) pair.
 *
 * @param teacher the teacher's name
 * @param course  one course taught by that teacher
 */
case class Course(
  teacher: String,
  course: String
)