Developing a StructType/Row-to-JSON Function Using SparkSQL's Built-in Function Interface

Requirement

 

Convert everything under a StructType-typed field of a DataFrame into a JSON string.

Spark version: 1.6.1


Approach

 

  • DataFrame has a toJSON method that converts every Row into a JSON string and returns an RDD[String]
  • DataFrame.write.json writes the data out as JSON-formatted files

Tracing both code paths shows that they ultimately call the org.apache.spark.sql.execution.datasources.json.JacksonGenerator class in the Spark source, which uses Jackson to generate a JSON string from a given StructType, JsonGenerator and InternalRow.
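
For reference, here is a minimal, standalone sketch of that call pattern. It assumes Spark 1.6.1, where JacksonGenerator.apply is curried as (StructType, JsonGenerator)(InternalRow) and is private[sql], so the snippet has to be compiled under the org.apache.spark.sql package; the object name JacksonGeneratorDemo and the helper rowToJson are made up for illustration.

package org.apache.spark.sql

import java.io.CharArrayWriter

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical demo object, not part of Spark or of the expression developed below.
object JacksonGeneratorDemo {

  // Serialize one InternalRow matching `schema` into a JSON string.
  def rowToJson(schema: StructType, row: InternalRow): String = {
    val writer = new CharArrayWriter
    val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
    JacksonGenerator(schema, gen)(row) // curried apply: (schema, gen)(row)
    gen.flush()
    gen.close()
    writer.toString
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("a", StringType),
      StructField("c", IntegerType)))
    // String columns are stored as UTF8String inside an InternalRow.
    val row = InternalRow(UTF8String.fromString("12"), 123)
    println(rowToJson(schema, row)) // {"a":"12","c":123}
  }
}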


Implementation

 

Our function takes a single argument, the column to convert, so it is implemented as a UnaryExpression under the org.apache.spark.sql.catalyst.expressions package.

The functionality was later extended so that inputs other than StructType can also be converted.

 

package org.apache.spark.sql.catalyst.expressions


import java.io.CharArrayWriter


import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratedExpressionCode
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType


import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.unsafe.types.UTF8String


/**
 * Converts a StructType-typed field into a JSON String.
 * @author yizhu.sun 2016-08-30
 */
case class Json(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {


  override def dataType: DataType = StringType
  override def inputTypes: Seq[DataType] = Seq(child.dataType)


  // Non-StructType inputs are wrapped in a single-field struct named "col" so that
  // JacksonGenerator can serialize them; the wrapper is stripped again afterwards.
  val inputStructType: StructType = child.dataType match {
    case st: StructType => st
    case _ => StructType(Seq(StructField("col", child.dataType, child.nullable, Metadata.empty)))
  }


  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess


  // See org.apache.spark.sql.DataFrame.toJSON
  // See org.apache.spark.sql.execution.datasources.json.JsonOutputWriter.writeInternal
  protected override def nullSafeEval(data: Any): UTF8String = {
    val writer = new CharArrayWriter
    val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
    val internalRow = child.dataType match {
      case _: StructType => data.asInstanceOf[InternalRow]
      case _ => InternalRow(data)
    }
    JacksonGenerator(inputStructType, gen)(internalRow)
    gen.flush()
    gen.close()
    val json = writer.toString
    UTF8String.fromString(
      child.dataType match {
        case _: StructType => json
        // For non-struct input, strip the {"col": ...} wrapper and keep only the JSON-encoded value.
        case _ => json.substring(json.indexOf(":") + 1, json.lastIndexOf("}"))
      })
  }


  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val writer = ctx.freshName("writer")
    val gen = ctx.freshName("gen")
    val st = ctx.freshName("st")
    val json = ctx.freshName("json")
    val typeJson = inputStructType.json
    def getDataExp(data: Any) =
      child.dataType match {
        case _: StructType => s"$data"
        case _ => s"new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{$data})"
      }
    def formatJson(json: String) =
      child.dataType match {
        case _: StructType => s"$json"
        case _ => s"""$json.substring($json.indexOf(":") + 1, $json.lastIndexOf("}"))"""
      }
    nullSafeCodeGen(ctx, ev, (row) => {
      s"""
        | com.fasterxml.jackson.core.JsonGenerator $gen = null;
        | try {
        |   org.apache.spark.sql.types.StructType $st = ${classOf[Json].getName}.getStructType("${typeJson.replace("\"", "\\\"")}");
        |   java.io.CharArrayWriter $writer = new java.io.CharArrayWriter();
        |   $gen = new com.fasterxml.jackson.core.JsonFactory().createGenerator($writer).setRootValueSeparator(null);
        |   org.apache.spark.sql.execution.datasources.json.JacksonGenerator.apply($st, $gen, ${getDataExp(row)});
        |   $gen.flush();
        |   String $json = $writer.toString();
        |   ${ev.value} = UTF8String.fromString(${formatJson(json)});
        | } catch (Exception e) {
        |   ${ev.isNull} = true;
        | } finally {
        |   if ($gen != null) $gen.close();
        | }
       """.stripMargin
    })
  }


}


object Json {


  val structTypeCache = collection.mutable.Map[String, StructType]() // [json, type]


  def getStructType(json: String): StructType = {
    structTypeCache.getOrElseUpdate(json, {
      println(">>>>> get StructType from json:")
      println(json)
      DataType.fromJson(json).asInstanceOf[StructType]
    })
  }


}



Registration

 

Note that SQLContext.functionRegistry has protected[sql] visibility, so the registration code below must be executed from code that lives in the org.apache.spark.sql package (see the helper sketch after the snippet).


val (name, (info, builder)) = FunctionRegistry.expression[Json]("json")

sqlContext.functionRegistry.registerFunction(name, info, builder)
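
Because of that protected[sql] restriction, one way to package the registration is to compile a small helper of your own into the org.apache.spark.sql package. This is only a sketch; the object name JsonFunctionRegistrar is made up.

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Json

// Lives in org.apache.spark.sql so that the protected[sql] member
// SQLContext.functionRegistry is accessible.
object JsonFunctionRegistrar {
  def register(sqlContext: SQLContext): Unit = {
    val (name, (info, builder)) = FunctionRegistry.expression[Json]("json")
    sqlContext.functionRegistry.registerFunction(name, info, builder)
  }
}

After calling JsonFunctionRegistrar.register(sqlContext), the json(...) function can be used in SQL as in the test below.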


Test


val subSchema = StructType(Array(
  StructField("a", StringType, true),
  StructField("b", StringType, true),
  StructField("c", IntegerType, true)))

val schema = StructType(Array(
  StructField("x", subSchema, true)))

val rdd = sc.makeRDD(Seq(Row(Row("12", null, 123)), Row(Row(null, "2222", null))))

val df = sqlContext.createDataFrame(rdd, schema)

df.registerTempTable("df")

import sqlContext.sql

sql("select x, x.a from df").show
sql("select x, x.a from df").printSchema
sql("select json(x), json(x.a) from df").show
sql("select json(x), json(x.a) from df").printSchema



 

Results



+----------------+----+
|x               |a   |
+----------------+----+
|[12,null,123]   |12  |
|[null,2222,null]|null|
+----------------+----+

root
 |-- x: struct (nullable = true)
 |    |-- a: string (nullable = true)
 |    |-- b: string (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- a: string (nullable = true)

>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}},{"name":"c","type":"integer","nullable":true,"metadata":{}}]}
>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"col","type":"string","nullable":true,"metadata":{}}]}

+------------------+----+
|_c0               |_c1 |
+------------------+----+
|{"a":"12","c":123}|"12"|
|{"b":"2222"}      |null|
+------------------+----+

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)




Points to note

 

  1. There are generally two ways to add custom functions to SparkSQL. One is to register a simple function through the public API, i.e. by calling sqlContext.udf.register. The other is to use the registration mechanism of SparkSQL's built-in functions, which is the approach taken in this article. The former is easier to develop with, but it cannot express more complex behavior, such as obtaining the StructType of the incoming InternalRow as this example requires, or generic signatures like def fun(arg: Seq[T]): T (sqlContext.udf.register cannot register a function whose return type is Any).
  2. The tricky part of implementing genCode in this example was constructing the StructType object inside the generated Java code. This was ultimately solved through serialization: the StructType is first serialized to a String with StructType.json, and the generated Java code then calls DataType.fromJson to deserialize that String back into a StructType (see the sketch right after this list).
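
A minimal sketch of that round-trip, standalone and outside the generated code, looks like this:

import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("a", StringType, nullable = true),
  StructField("c", IntegerType, nullable = true)))

// Serialize the schema to its JSON representation (what Json.genCode embeds as a string literal) ...
val schemaJson: String = schema.json
// ... then rebuild the StructType from that JSON, as Json.getStructType does at runtime.
val restored = DataType.fromJson(schemaJson).asInstanceOf[StructType]
assert(restored == schema)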
