Let's look at how to define and use a UDTF in SparkSQL.
Environment: Spark 2.2.0, Hive 2.1.1
(The Hive-UDTF registration approach below applies to Spark 1.* with Hive 2.1.1.)
At the moment Spark does not support UDTFs natively. In older versions (Spark 1.*) you could implement a UDTF through Hive's GenericUDTF and register it as a function.
UDTF class
package com.spark.test.offline.udf

import java.util

import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}

/**
 * Created by szh on 2020/6/1.
 */
class CustomerUDTF extends GenericUDTF {

  // Declare the output schema: a single string column (the column name "ch" is chosen here)
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    val fieldNames = new util.ArrayList[String]()
    val fieldOIs = new util.ArrayList[ObjectInspector]()
    fieldNames.add("ch")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  override def process(objects: Array[AnyRef]): Unit = {
    // Split the input string into single characters
    val strLst = objects(0).toString.split("")
    for (i <- strLst) {
      val tmp: Array[String] = new Array[String](1)
      tmp(0) = i
      // forward must be passed an array, even when there is only one column
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}
Registering the function in Spark:
spark.sqlContext.sql("CREATE TEMPORARY FUNCTION NEWUDTF as 'com.spark.test.offline.udf.CustomerUDTF'")
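If the registration succeeds (per SPARK-21101, referenced below, it may fail on Spark 2.2), the function is then called like any Hive UDTF. The following is only a sketch: it assumes a temporary view named users with a string column name, and the output column ch matches the one declared in initialize above.

// Hypothetical input: a temporary view "users" with a string column "name"
// Direct call: each input row expands into one output row per character
spark.sql("SELECT NEWUDTF(name) FROM users").show()

// Hive-style LATERAL VIEW keeps the original columns alongside the expanded one
spark.sql("SELECT id, name, t.ch FROM users LATERAL VIEW NEWUDTF(name) t AS ch").show()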
Related articles:
1.error running Hive temporary UDTF on latest Spark 2.2
https://issues.apache.org/jira/browse/SPARK-21101
2.SparkSQL 自定义算子UDF、UDAF、UDTF
https://blog.csdn.net/laksdbaksjfgba/article/details/87162906
The logic to implement: generate 10 records from each input record.
Use the flatMap operator to expand each Row.
The complete code is as follows:
package com.spark.test.offline.udf

import com.spark.test.offline.optimize.del.User
import org.apache.spark.SparkConf
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable.ArrayBuffer

/**
 * Created by szh on 2020/6/1.
 */
object SparkSQLUdtf {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf
    conf
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // .set("spark.kryo.registrationRequired", "true")
      // Method 1: register the Kryo classes explicitly
      .registerKryoClasses(
        Array(
          classOf[User]
          , classOf[scala.collection.mutable.WrappedArray.ofRef[_]]
        ))

    val spark = SparkSession
      .builder()
      .appName("sparkSql")
      .master("local[1]")
      .config(conf)
      .getOrCreate()
    val sc = spark.sparkContext
    sc.setLogLevel("ERROR")

    val orgRDD = sc.parallelize(Seq(
      User(1, "cc", "bj")
      , User(2, "aa", "bj")
      , User(3, "qq", "bj")
      , User(4, "pp", "bj")
    ))
    val orgDF = spark
      .createDataFrame(orgRDD)
    orgDF.show()

    //spark.sqlContext.sql("CREATE TEMPORARY FUNCTION NEWUDTF as 'com.spark.test.offline.udf.CustomerUDTF'")

    // Expand each Row into 10 Rows; +=: prepends, so "no" comes out in descending order
    val midRDD = orgDF.rdd.flatMap(tmp => {
      val x = ArrayBuffer[Row]()
      for (i <- 1 to 10) {
        x.+=:(Row(tmp.getInt(0), tmp.getString(1), tmp.getString(2), i))
      }
      x
    })
    println(midRDD.count())

    val finalDF = spark.createDataFrame(midRDD, StructType(
      Array(
        StructField("id", IntegerType)
        , StructField("name", StringType)
        , StructField("city", StringType)
        , StructField("no", IntegerType)
      )
    ))
    finalDF.show()

    // Keep the application alive so the Spark UI can be inspected
    Thread.sleep(10 * 60 * 1000)

    sc.stop()
    spark.stop()
  }
}
The output is as follows:
+---+----+----+
| id|name|city|
+---+----+----+
| 1| cc| bj|
| 2| aa| bj|
| 3| qq| bj|
| 4| pp| bj|
+---+----+----+
40
+---+----+----+---+
| id|name|city| no|
+---+----+----+---+
| 1| cc| bj| 10|
| 1| cc| bj| 9|
| 1| cc| bj| 8|
| 1| cc| bj| 7|
| 1| cc| bj| 6|
| 1| cc| bj| 5|
| 1| cc| bj| 4|
| 1| cc| bj| 3|
| 1| cc| bj| 2|
| 1| cc| bj| 1|
| 2| aa| bj| 10|
| 2| aa| bj| 9|
| 2| aa| bj| 8|
| 2| aa| bj| 7|
| 2| aa| bj| 6|
| 2| aa| bj| 5|
| 2| aa| bj| 4|
| 2| aa| bj| 3|
| 2| aa| bj| 2|
| 2| aa| bj| 1|
+---+----+----+---+
only showing top 20 rows
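For comparison, the same one-to-many expansion can also be written without dropping to the RDD API, using the built-in explode function. This is a minimal sketch against the orgDF defined above; the generated column name no mirrors the schema of the flatMap version, though the row ordering may differ:

import org.apache.spark.sql.functions.{array, explode, lit}

// Attach a literal array 1..10 to every row, then explode it into 10 rows per input row
val expandedDF = orgDF.withColumn("no", explode(array((1 to 10).map(i => lit(i)): _*)))
expandedDF.show()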
Maven (pom.xml):

<project xmlns="http://maven.apache.org/POM/4.0.0">
    <parent>
        <artifactId>spark-test</artifactId>
        <groupId>www.sunzhenhua.com</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <artifactId>spark-offline</artifactId>

    <dependencies>
        <dependency><groupId>org.apache.spark</groupId><artifactId>spark-core_2.11</artifactId></dependency>
        <dependency><groupId>org.apache.spark</groupId><artifactId>spark-sql_2.11</artifactId></dependency>
        <dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming_2.11</artifactId></dependency>
        <dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming-kafka-0-10_2.11</artifactId></dependency>
        <dependency><groupId>org.scala-lang</groupId><artifactId>scala-library</artifactId></dependency>
        <dependency><groupId>org.apache.hive</groupId><artifactId>hive-exec</artifactId><version>2.1.0</version></dependency>
    </dependencies>

    <build>
        <resources>
            <resource>
                <directory>src/main/resources</directory>
                <filtering>true</filtering>
            </resource>
        </resources>
        <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
        <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals><goal>shade</goal></goals>
                    </execution>
                </executions>
                <configuration>
                    <!-- Element names for the next two flags were lost in the original formatting;
                         createDependencyReducedPom=false and shadedArtifactAttached=true are assumed -->
                    <createDependencyReducedPom>false</createDependencyReducedPom>
                    <shadedArtifactAttached>true</shadedArtifactAttached>
                    <artifactSet>
                        <includes><include>*:*</include></includes>
                    </artifactSet>
                    <filters>
                        <filter>
                            <artifact>*:*</artifact>
                            <excludes>
                                <exclude>META-INF/*.SF</exclude>
                                <exclude>META-INF/*.DSA</exclude>
                                <exclude>META-INF/*.RSA</exclude>
                            </excludes>
                        </filter>
                    </filters>
                    <transformers>
                        <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                            <resource>reference.conf</resource>
                        </transformer>
                    </transformers>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>