Just like the user-defined functions in Hive, Spark can use custom functions to add new functionality.
There are three kinds of custom functions in Spark (built-in analogues of each are sketched right after this list):
UDF (User-Defined Function)
one row in, one row out
UDAF (User-Defined Aggregation Function)
multiple rows in, one row out
UDTF (User-Defined Table-Generating Function)
one row in, multiple rows out
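For intuition, Spark's built-in functions already cover these three shapes; a minimal sketch, assuming a DataFrame df with a string column s:
import org.apache.spark.sql.functions.{col, count, explode, split, upper}
df.select(upper(col("s")))                 // UDF-like: one row in, one row out
df.agg(count(col("s")))                    // UDAF-like: many rows in, one row out
df.select(explode(split(col("s"), ",")))   // UDTF-like: one row in, many rows out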
hobbies.txt
alice jogging,Coding,cooking
lina travel,dance
package nj.zb.kb09.sql

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

case class Hobbies(name: String, hobbies: String)

object SparkUDFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDFDemo").getOrCreate()
    import spark.implicits._
    val sc: SparkContext = spark.sparkContext

    // Load the raw text and map each line to a Hobbies case class
    val rdd: RDD[String] = sc.textFile("in/hobbies.txt")
    val df: DataFrame = rdd.map(x => x.split(" ")).map(x => Hobbies(x(0), x(1))).toDF()
    df.printSchema()
    df.show()
    df.createOrReplaceTempView("hobbies")

    // Register a UDF that counts how many hobbies are in the comma-separated list
    spark.udf.register("hobby_num", (v: String) => v.split(",").size)

    // Use the UDF in a SQL query
    val frame: DataFrame = spark.sql("select name, hobbies, hobby_num(hobbies) as hobbynum from hobbies")
    frame.show()
  }
}
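The same UDF can also be used without SQL, through the DataFrame API; a minimal sketch reusing the df and column names from the example above:
import org.apache.spark.sql.functions.{col, udf}

val hobbyNum = udf((v: String) => v.split(",").length)
df.withColumn("hobbynum", hobbyNum(col("hobbies"))).show()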
udf2.txt
Hello
abc
study
small
package nj.zb.kb09.sql

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object SparkUDFDemo2 {
  def main(args: Array[String]): Unit = {
    // Create the SparkSession
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDFDemo2").getOrCreate()

    // Read the file; spark.read.textFile returns a Dataset[String] with a single column named "value"
    val fileDs: Dataset[String] = spark.read.textFile("in/udf2.txt")
    fileDs.printSchema()
    fileDs.show()

    // Register a function named smallToBig that takes a String and returns it in upper case
    spark.udf.register("smallToBig", (str: String) => str.toUpperCase())

    // Create a temporary view
    fileDs.createOrReplaceTempView("t_word")

    // Use the custom function
    val df: DataFrame = spark.sql("select value, smallToBig(value) from t_word")
    df.printSchema()
    df.show()
  }
}
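A registered UDF can also be called outside spark.sql, for example through selectExpr directly on the Dataset (the upper_value alias here is just for illustration):
fileDs.selectExpr("value", "smallToBig(value) as upper_value").show()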
To implement a UDAF, extend UserDefinedAggregateFunction and override the following methods:
inputSchema: the schema of the input data
bufferSchema: the schema of the intermediate (buffer) results
dataType: the data type of the final result
deterministic: whether the same input always produces the same output; normally true
initialize: sets the initial values of the buffer
update: folds each incoming row into the intermediate result (update runs inside each partition)
merge: global aggregation (combines the per-partition buffers)
evaluate: computes the final result
Requirement: compute the average age for each sex.
udaf.json
{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
package nj.zb.kb09.sql

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Define and use a custom UDAF
object SparkUDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDAFDemo").getOrCreate()
    val df: DataFrame = spark.read.json("in/udaf.json")
    df.printSchema()
    df.show()

    // Create and register the custom UDAF
    val myUdaf = new MyAgeAvgFunction
    spark.udf.register("myAvgAge", myUdaf)

    // Create a temporary view
    df.createTempView("userinfo")

    // Use the custom function
    val resultDF: DataFrame = spark.sql("select sex, myAvgAge(age) from userinfo group by sex")
    resultDF.printSchema()
    resultDF.show()

    // Compare with the built-in avg function
    println("-----------------")
    spark.sql("select sex, avg(age) from userinfo group by sex").show()
  }
}
class MyAgeAvgFunction extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
    // StructType(StructField("age", LongType) :: Nil)
  }

  // Schema of the aggregation buffer (intermediate results)
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
    // StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // Data type of the final result
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e. the same input always yields the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two buffers (combine the per-partition results)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // total age
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // total count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
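As a side note, UserDefinedAggregateFunction is deprecated from Spark 3.0 onward; the recommended replacement is the typed Aggregator, registered as an untyped UDAF via functions.udaf. A minimal sketch of the same average-age logic under that API (the AvgBuffer and MyAgeAvg names are made up for this example):
import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

case class AvgBuffer(var sum: Long, var count: Long)

object MyAgeAvg extends Aggregator[Long, AvgBuffer, Double] {
  override def zero: AvgBuffer = AvgBuffer(0L, 0L)                                              // initial buffer
  override def reduce(b: AvgBuffer, age: Long): AvgBuffer = { b.sum += age; b.count += 1; b }   // per-row update
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  override def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count                          // final result
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Registration (Spark 3.0+), e.g. inside main:
// spark.udf.register("myAvgAge", functions.udaf(MyAgeAvg, Encoders.scalaLong))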
udtf.txt
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
package nj.zb.kb09.sql

import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hive UDTF
class MyUDTF extends GenericUDTF {
  // initialize does two things: 1) validate the input arguments, 2) define the output columns.
  // More than one output column can be declared, so a UDTF can produce multiple rows and columns.
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("exactly one argument is expected")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("argument type mismatch: a primitive type is expected")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    // Type of the output column
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  /* Input:  Hadoop scala kafka hive hbase Oozie
     Output: column `type` (String)
       Hadoop
       scala
       kafka
       hive
       hbase
       Oozie
  */
  // process handles the data; each call receives the arguments of exactly one input row
  override def process(objects: Array[AnyRef]): Unit = {
    // Split the input string on spaces into an array of words
    val strings: Array[String] = objects(0).toString.split(" ")
    println(strings.mkString(","))
    for (str <- strings) {
      val tmp: Array[String] = new Array[String](1)
      tmp(0) = str
      // forward must be given an array with one element per output column, even if there is only one
      forward(tmp)
    }
  }

  override def close(): Unit = {
  }
}
object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDTFDemo").enableHiveSupport().getOrCreate()
    import spark.implicits._
    val sc: SparkContext = spark.sparkContext

    // Load the data, keep only the "ls" row, and turn it into a DataFrame
    val lines: RDD[String] = sc.textFile("in/udtf.txt")
    val stuDf: DataFrame = lines.map(_.split("//")).filter(x => x(1).equals("ls")).map(x => (x(0), x(1), x(2))).toDF("id", "name", "class")
    stuDf.printSchema()
    stuDf.show()
    stuDf.createOrReplaceTempView("student")

    // Register the Hive UDTF (this requires enableHiveSupport) and use it in SQL
    spark.sql("CREATE TEMPORARY FUNCTION MyUDTF AS 'nj.zb.kb09.sql.MyUDTF'")
    val resultDF: DataFrame = spark.sql("select MyUDTF(class) from student")
    resultDF.printSchema()
    resultDF.show()
  }
}
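For this particular need a Hive UDTF is not strictly required: the built-in split and explode functions give the same one-row-to-many-rows expansion. For comparison, either of the following produces the same result against the student view / stuDf above:
spark.sql("select explode(split(class, ' ')) as type from student").show()

import org.apache.spark.sql.functions.{col, explode, split}
stuDf.select(explode(split(col("class"), " ")).as("type")).show()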