Spark UDTF的定义与使用

UDTF概述

  • UDTF(一进多出):对每个列中的每一个元素进行操作,返回一个列(行转列)
  • UDF(一进一出):UDF的定义与使用
  • UDAF(多进一出):UDTF的定义与使用

UDTF的定义

//创建class类继承GenericUDTF,重写initialize、process、close
class UDTF类名 extends GenericUDTF {
     }

UDTF的使用

//在获取SparkSession实例时需要加上.enableHiveSupport(),否则无法使用
val spark = SparkSession.builder().appName("UDTF").master("local[*]").enableHiveSupport().getOrCreate()

//注册UDTF
spark.sql("CREATE TEMPORARY FUNCTION 自定义UDTF别名 AS 'UDTF类名'")

UDTF示例

/*
UDTF.txt测试数据
01//zs//Hadoop scala
02//ls//Hadoop kafka
03//ww//spark hive sqoop
*/
object SparkUDTFDemo {
     
  def main(args: Array[String]): Unit = {
     
  //在获取SparkSession实例时加上enableHiveSupport
    val spark = SparkSession.builder().appName("UDTF").master("local[*]").enableHiveSupport().getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext
    val rdd = sc.textFile("in/UDTF.txt")
    val df = rdd.map(_.split("//")).map(x=>(x(0),x(1),x(2))).toDF("id","name","class")
    df.createOrReplaceTempView("student")
    //注册UDTF,如果报错说找不到UDTF类,可像我这里写的一样,加上包名nj.kgc.类名
    spark.sql("CREATE TEMPORARY FUNCTION udtf AS 'nj.kgc.myUDTF'")
    //对比原始
    spark.sql("select name,class from student").show()
    /*
    +----+----------------+
	|name|           class|
	+----+----------------+
	|  zs|    Hadoop scala|
	|  ls|    Hadoop kafka|
	|  ww|spark hive sqoop|
	+----+----------------+
    */
    //使用UDTF后
    spark.sql("select name,udtf(class) from student").show()
    /*
    +----+------+
	|name|  type|
	+----+------+
	|  zs|Hadoop|
	|  zs| scala|
	|  ls|Hadoop|
	|  ls| kafka|
	|  ww| spark|
	|  ww|  hive|
	|  ww| sqoop|
	+----+------+
    */

  }

}
//创建UDTF类继承GenericUDTF并重写下面的方法
class myUDTF extends GenericUDTF {
     
	
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
     
   	//判断传入的参数是否只有一个
    if (argOIs.length != 1) {
     
      throw new UDFArgumentException("有且只能有一个参数")
    }
    //判断参数类型
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
     
      throw new UDFArgumentException("参数类型不匹配")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  override def process(objects: Array[AnyRef]): Unit = {
     
  	//将传入的数据拆分,形成一个Array数组
    val strings: Array[String] = objects(0).toString.split(" ")
    //遍历集合
    for (elem <- strings) {
     
   	 //每次循环都创建一个新数组,长度为1
      val tmp = new Array[String](1)
      //将循环到的数据传入数组
      tmp(0) = elem
      //显示出去,必须传入的时数组
      forward(tmp)
    }
  }

//关闭方法,这里就不写了
  override def close(): Unit = {
     }
}

你可能感兴趣的:(菜鸟也学大数据,Spark,spark,大数据,udf)