Spark provides a large number of built-in functions, and developers can also define their own custom functions.
Steps for defining a custom function in Spark:
1. Define the function
2. Register the function
SparkSession.udf.register(): only usable from sql()
functions.udf(): also works with the DataFrame API (a sketch of this variant follows the first example's output below)
3. Call the function
The directory D:\test\t\ contains a file hobbies.txt with the following content:
alice jogging,Coding,cooking
lina travel,dance
Requirement: count the number of hobbies per user.
Expected output format:
alice jogging,Coding,cooking 3
lina travel,dance 2
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkUDFDemo {
  // Case class describing one record
  case class Hobbies(name: String, hobbies: String)

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local[1]")
      .appName("SparkUDFDemo")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    // This implicits import is required, otherwise the RDD cannot be converted to a DataFrame
    import spark.implicits._

    val rdd: RDD[String] = sc.textFile("D:\\test\\t\\hobbies.txt")
    val df: DataFrame = rdd.map(x => x.split("\t")).map(x => Hobbies(x(0), x(1))).toDF()
    // df.printSchema()
    // df.show()
    // createOrReplaceTempView replaces the deprecated registerTempTable
    df.createOrReplaceTempView("hobbies")
    // Register the custom UDF; note that it is an anonymous function
    spark.udf.register("hoby_num", (s: String) => s.split(",").size)
    val frame: DataFrame = spark.sql("select name, hobbies, hoby_num(hobbies) as hobnum from hobbies")
    frame.show()
  }
}
Output:
+-----+--------------------+------+
| name| hobbies|hobnum|
+-----+--------------------+------+
|alice|jogging,Coding,co...| 3|
| lina| travel,dance| 2|
+-----+--------------------+------+
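The example above registers the UDF with spark.udf.register, which makes it callable from sql(). To apply the same logic directly through the DataFrame API, functions.udf() can be used instead. A minimal sketch, reusing the df DataFrame from the example above; the names hobyNum and hobnum are just illustrative choices:

import org.apache.spark.sql.functions.{col, udf}

// Wrap the same anonymous function as a column-expression UDF
val hobyNum = udf((s: String) => s.split(",").size)

// Call it directly on DataFrame columns, no SQL or registration needed
df.withColumn("hobnum", hobyNum(col("hobbies"))).show()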
UDAF stands for User Defined Aggregate Function. How does an aggregate function differ from an ordinary function? An ordinary function takes a single row of input and produces one output, while an aggregate function takes a group of rows (usually many) as input and combines them into a single output value.
Using a UDAF means extending UserDefinedAggregateFunction. The steps are:
Define a class that extends UserDefinedAggregateFunction and implement the method for each stage
Register the UDAF in Spark and bind it to a name
Call it in SQL statements using that name
The directory D:\test\t\ contains a file user.json with the following content:
{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
Requirement: compute the average age for each sex.
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object SparkUDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SparkUDAFDemo")
      .getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    val df = spark.read.json("D:\\test\\t\\user.json")
    // Create and register the custom UDAF
    val myUdaf = new MyAgeAvgFunction
    spark.udf.register("myAvgAge", myUdaf)
    df.createTempView("userinfo")
    val resultDF = spark.sql("select myAvgAge(age) as avg_age from userinfo group by sex")
    resultDF.printSchema()
    resultDF.show()
  }
}

class MyAgeAvgFunction extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
    // Alternative way of writing it:
    // StructType(StructField("age", LongType) :: Nil)
  }

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
    // Alternative way of writing it:
    // StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // Data type of the return value
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e., the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Process one new input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two aggregation buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Total age
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // Count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
Result:
root
|-- avg_age: double (nullable = true)
+------------------+
| avg_age|
+------------------+
|20.666666666666668|
|18.666666666666668|
+------------------+
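Sanity check against the input data: for sex = man the ages are 20, 24, 18, so (20 + 24 + 18) / 3 = 62 / 3 ≈ 20.67; for sex = woman the ages are 17, 19, 20, so 56 / 3 ≈ 18.67, which matches the two rows above.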
Reference blog: https://www.cnblogs.com/cc11001100/p/9471859.html
(That blog also covers another approach, extending Aggregator; a minimal sketch of it follows.)
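For reference, here is a minimal sketch of that alternative: a type-safe Aggregator computing the same average age. The names AvgBuffer and MyAvgAggregator are illustrative, and the blog's own version may differ in detail:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Intermediate buffer: running sum and count
case class AvgBuffer(var sum: Long, var count: Long)

// Type-safe aggregator: Long input, AvgBuffer buffer, Double output
object MyAvgAggregator extends Aggregator[Long, AvgBuffer, Double] {
  override def zero: AvgBuffer = AvgBuffer(0L, 0L)
  override def reduce(b: AvgBuffer, age: Long): AvgBuffer = {
    b.sum += age; b.count += 1; b
  }
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  override def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

In Spark 3.x such an aggregator can be registered for SQL with spark.udf.register("myAvgAge", functions.udaf(MyAvgAggregator)); in Spark 2.x, Aggregator is used through the typed Dataset API (toColumn) instead.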
The directory D:\test\t\ contains a file udtf.txt with the following content:
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
Requirement: expand ls's class list "Hadoop scala kafka hive hbase Oozie" into one row per word, like this:
type   (header)
Hadoop
scala
kafka
hive
hbase
Oozie
import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.spark.sql.SparkSession

object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .enableHiveSupport() // Hive support is required for a Hive UDTF
      .appName("SparkUDTFDemo")
      .getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._

    val lines = sc.textFile("D:\\test\\t\\udtf.txt")
    val stuDF = lines.map(_.split("//")).filter(x => x(1).equals("ls"))
      .map(x => (x(0), x(1), x(2))).toDF("id", "name", "class")
    // stuDF.printSchema()
    // stuDF.show()
    stuDF.createTempView("student")
    // Note: if the class lives in a package, the package name must be included after AS!
    spark.sql("CREATE TEMPORARY FUNCTION myUDTF AS 'kb09.sql.myUDTF'")
    val resultDF = spark.sql("select myUDTF(class) from student")
    resultDF.show()
  }
}

class myUDTF extends GenericUDTF {
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("Exactly one argument must be passed in")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("Argument type does not match")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    // Type of the output field
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  // Input: Hadoop scala kafka hive hbase Oozie
  override def process(objects: Array[AnyRef]): Unit = {
    // Split the string into an array of words on spaces
    val strings = objects(0).toString.split(" ")
    println(strings) // prints the array reference, which is the first line of the output below
    for (str <- strings) {
      val tmp = new Array[String](1)
      tmp(0) = str
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}
Output:
[Ljava.lang.String;@6d0e1408
+------+
| type|
+------+
|Hadoop|
| scala|
| kafka|
| hive|
| hbase|
| Oozie|
+------+
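As a side note, for this particular expansion a Hive UDTF is not strictly required: Spark SQL's built-in split and explode functions produce the same one-word-per-row result. A minimal sketch against the student view defined above:

// Built-in alternative: split the space-separated string and explode it into rows
spark.sql("select explode(split(class, ' ')) as type from student").show()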