类似于hive当中的自定义函数, spark同样可以使用自定义函数来实现新的功能。
spark中的自定义函数有如下3类
1.UDF(User-Defined-Function)
输入一行,输出一行
2.UDAF(User-Defined Aggregation Funcation)
输入多行,输出一行
3.UDTF(User-Defined Table-Generating Functions)
输入一行,输出多行
●需求
有udf.txt数据格式如下:
Hello
abc
study
small
通过自定义UDF函数将每一行数据转换成大写
select value,smallToBig(value) from t_word
●代码演示
package cn.itcast.sql
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, SparkSession}
object UDFDemo {
def main(args: Array[String]): Unit = {
//1.创建SparkSession
val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkSQL").getOrCreate()
val sc: SparkContext = spark.sparkContext
sc.setLogLevel("WARN")
//2.读取文件
val fileDS: Dataset[String] = spark.read.textFile("D:\\data\\udf.txt")
fileDS.show()
/*
+----------+
| value|
+----------+
|helloworld|
| abc|
| study|
| smallWORD|
+----------+
*/
/*
将每一行数据转换成大写
select value,smallToBig(value) from t_word
*/
//注册一个函数名称为smallToBig,功能是传入一个String,返回一个大写的String
spark.udf.register("smallToBig",(str:String) => str.toUpperCase())
fileDS.createOrReplaceTempView("t_word")
//使用我们自己定义的函数
spark.sql("select value,smallToBig(value) from t_word").show()
/*
+----------+---------------------+
| value|UDF:smallToBig(value)|
+----------+---------------------+
|helloworld| HELLOWORLD|
| abc| ABC|
| study| STUDY|
| smallWORD| SMALLWORD|
+----------+---------------------+
*/
sc.stop()
spark.stop()
}
}
●需求
有udaf.json数据内容如下
{"name":"Michael","salary":3000}
{"name":"Andy","salary":4500}
{"name":"Justin","salary":3500}
{"name":"Berta","salary":4000}
求取平均工资
●继承UserDefinedAggregateFunction方法重写说明
inputSchema:输入数据的类型
bufferSchema:产生中间结果的数据类型
dataType:最终返回的结果类型
deterministic:确保一致性,一般用true
initialize:指定初始值
update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
merge:全局聚合(将每个分区的结果进行聚合)
evaluate:计算最终的结果
●代码演示
package cn.itcast.sql
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
object UDAFDemo {
def main(args: Array[String]): Unit = {
//1.获取sparkSession
val spark: SparkSession = SparkSession.builder().appName("SparkSQL").master("local[*]").getOrCreate()
val sc: SparkContext = spark.sparkContext
sc.setLogLevel("WARN")
//2.读取文件
val employeeDF: DataFrame = spark.read.json("D:\\data\\udaf.json")
//3.创建临时表
employeeDF.createOrReplaceTempView("t_employee")
//4.注册UDAF函数
spark.udf.register("myavg",new MyUDAF)
//5.使用自定义UDAF函数
spark.sql("select myavg(salary) from t_employee").show()
//6.使用内置的avg函数
spark.sql("select avg(salary) from t_employee").show()
}
}
class MyUDAF extends UserDefinedAggregateFunction{
//输入的数据类型的schema
override def inputSchema: StructType = {
StructType(StructField("input",LongType)::Nil)
}
//缓冲区数据类型schema,就是转换之后的数据的schema
override def bufferSchema: StructType = {
StructType(StructField("sum",LongType)::StructField("total",LongType)::Nil)
}
//返回值的数据类型
override def dataType: DataType = {
DoubleType
}
//确定是否相同的输入会有相同的输出
override def deterministic: Boolean = {
true
}
//初始化内部数据结构
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
//更新数据内部结构,区内计算
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//所有的金额相加
buffer(0) = buffer.getLong(0) + input.getLong(0)
//一共有多少条数据
buffer(1) = buffer.getLong(1) + 1
}
//来自不同分区的数据进行合并,全局合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) =buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
//计算输出数据值
override def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble / buffer.getLong(1)
}
}
●需求
有udtf.txt数据内容如下
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
数据以//分隔,遍历ls学的大数据组件
●代码演示
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("t01").getOrCreate()
val sc = spark.sparkContext
sc.setLogLevel("warn")
val filedata = sc.textFile("file:///F:\\udtf.txt")
val splitdata = filedata.map(_.split("//"))
val filterdata = splitdata.filter(a=>a(1).equals("ls"))
val line = filterdata.map(a=>(a(0),a(1),a(2)))
import spark.implicits._
val dataFrame = line.toDF("id","name","class")
dataFrame.show()
dataFrame.createOrReplaceTempView("student")
//CREATE TEMPORARY FUNCTION 自定义算子名称 as '算子实现类全限定名称'
spark.sql("CREATE TEMPORARY FUNCTION myUDTF as 'myUDTF' ")
spark.sql("select myUDTF(class) from student").show()
}
class myUDTF extends GenericUDTF{
//这个方法的作用:1.输入参数校验 2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
if (argOIs.length != 1) {
throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
}
if (argOIs(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
}
val fieldNames = new util.ArrayList[String]
val fieldOIs = new util.ArrayList[ObjectInspector]
fieldNames.add("type")
//这里定义的是输出列字段类型
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
}
//这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
override def process(args: Array[AnyRef]): Unit = {
//将字符串切分成单个字符的数组
val strLst = args(0).toString.split(" ")
for(i <- strLst){
var tmp:Array[String] = new Array[String](1)
tmp(0) = i
//调用forward方法,必须传字符串数组,即使只有一个元素
forward(tmp)
}
}
override def close(): Unit = {}
}