Requirement
Convert everything under a StructType field of a DataFrame into a JSON string.
Spark version: 1.6.1
Approach
Tracing those two code paths (DataFrame.toJSON and JsonOutputWriter.writeInternal, both referenced in the comments of the implementation below) shows that they ultimately call org.apache.spark.sql.execution.datasources.json.JacksonGenerator in the Spark source, which uses Jackson to produce a JSON string from a given StructType, JsonGenerator, and InternalRow.
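The call pattern, distilled into a standalone sketch (the helper name rowToJson is ours; JacksonGenerator is private[sql] in 1.6.x, so this only compiles under the org.apache.spark.sql package, as the implementation below does):

import java.io.CharArrayWriter
import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.StructType

// Serialize one InternalRow to a JSON string, given its schema.
def rowToJson(schema: StructType, row: InternalRow): String = {
  val writer = new CharArrayWriter
  val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
  JacksonGenerator(schema, gen)(row) // writes the row as one JSON object
  gen.flush()
  gen.close()
  writer.toString
}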
Development
Our function takes exactly one argument, the column to be converted, so we implement a UnaryExpression from the org.apache.spark.sql.catalyst.expressions package.
The functionality was later extended so that inputs that are not of StructType can be converted as well, as illustrated right below.
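Concretely, a non-struct input is wrapped in a single-field struct named col before serialization, and the wrapper is then stripped from the resulting JSON. A worked illustration of the stripping step (the literal is hypothetical):

val json = """{"col":"12"}"""
// Keep everything between the first ':' and the trailing '}'.
json.substring(json.indexOf(":") + 1, json.lastIndexOf("}")) // => "12" (the JSON quotes survive)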
package org.apache.spark.sql.catalyst.expressions

import java.io.CharArrayWriter

import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratedExpressionCode
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

/**
 * Convert a StructType field to a JSON string.
 * @author yizhu.sun, Aug 30, 2016
 */
case class Json(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = StringType
  override def inputTypes: Seq[DataType] = Seq(child.dataType)

  // Non-struct inputs are wrapped in a single-field struct named "col".
  val inputStructType: StructType = child.dataType match {
    case st: StructType => st
    case _ => StructType(Seq(StructField("col", child.dataType, child.nullable, Metadata.empty)))
  }

  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess

  // See org.apache.spark.sql.DataFrame.toJSON and
  // org.apache.spark.sql.execution.datasources.json.JsonOutputWriter.writeInternal
  protected override def nullSafeEval(data: Any): UTF8String = {
    val writer = new CharArrayWriter
    val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
    val internalRow = child.dataType match {
      case _: StructType => data.asInstanceOf[InternalRow]
      case _ => InternalRow(data)
    }
    JacksonGenerator(inputStructType, gen)(internalRow)
    gen.flush()
    gen.close()
    val json = writer.toString
    UTF8String.fromString(
      child.dataType match {
        case _: StructType => json
        // Strip the {"col": ...} wrapper added for non-struct inputs.
        case _ => json.substring(json.indexOf(":") + 1, json.lastIndexOf("}"))
      })
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val writer = ctx.freshName("writer")
    val gen = ctx.freshName("gen")
    val st = ctx.freshName("st")
    val json = ctx.freshName("json")
    val typeJson = inputStructType.json
    def getDataExp(data: Any) = child.dataType match {
      case _: StructType => s"$data"
      case _ => s"new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{$data})"
    }
    def formatJson(json: String) = child.dataType match {
      case _: StructType => s"$json"
      case _ => s"""$json.substring($json.indexOf(":") + 1, $json.lastIndexOf("}"))"""
    }
    nullSafeCodeGen(ctx, ev, (row) => {
      s"""
        | com.fasterxml.jackson.core.JsonGenerator $gen = null;
        | try {
        |   org.apache.spark.sql.types.StructType $st = ${classOf[Json].getName}.getStructType("${typeJson.replace("\"", "\\\"")}");
        |   java.io.CharArrayWriter $writer = new java.io.CharArrayWriter();
        |   $gen = new com.fasterxml.jackson.core.JsonFactory().createGenerator($writer).setRootValueSeparator(null);
        |   org.apache.spark.sql.execution.datasources.json.JacksonGenerator.apply($st, $gen, ${getDataExp(row)});
        |   $gen.flush();
        |   String $json = $writer.toString();
        |   ${ev.value} = UTF8String.fromString(${formatJson(json)});
        | } catch (Exception e) {
        |   ${ev.isNull} = true;
        | } finally {
        |   if ($gen != null) $gen.close();
        | }
      """.stripMargin
    })
  }

}

object Json {
  // Cache parsed schemas so each distinct type JSON is parsed only once.
  val structTypeCache = collection.mutable.Map[String, StructType]() // [json, type]
  def getStructType(json: String): StructType = {
    structTypeCache.getOrElseUpdate(json, {
      println(">>>>> get StructType from json:")
      println(json)
      DataType.fromJson(json).asInstanceOf[StructType]
    })
  }
}
Registration
Note that SQLContext.functionRegistry has protected[sql] visibility, so the snippet below only compiles inside the org.apache.spark.sql package.
val (name, (info, builder)) = FunctionRegistry.expression[Json]("json")
sqlContext.functionRegistry.registerFunction(name, info, builder)
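Since the registry is only reachable from inside that package, one workaround (a sketch; the helper object name JsonFunctionRegistrar is ours) is to compile a small registration shim under org.apache.spark.sql:

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Json

// Hypothetical shim: it lives in org.apache.spark.sql solely so that it can
// reach the protected[sql] functionRegistry on SQLContext.
object JsonFunctionRegistrar {
  def register(sqlContext: SQLContext): Unit = {
    val (name, (info, builder)) = FunctionRegistry.expression[Json]("json")
    sqlContext.functionRegistry.registerFunction(name, info, builder)
  }
}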
Test
val subSchema = StructType(Array(
  StructField("a", StringType, true),
  StructField("b", StringType, true),
  StructField("c", IntegerType, true)))
val schema = StructType(Array(
  StructField("x", subSchema, true)))
val rdd = sc.makeRDD(Seq(Row(Row("12", null, 123)), Row(Row(null, "2222", null))))
val df = sqlContext.createDataFrame(rdd, schema)
df.registerTempTable("df")
import sqlContext.sql
sql("select x, x.a from df").show
sql("select x, x.a from df").printSchema
sql("select json(x), json(x.a) from df").show
sql("select json(x), json(x.a) from df").printSchema
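The registered function can also be exercised from the DataFrame API; for example, selectExpr goes through the same parser and analyzer:

// Equivalent to the json(...) SQL queries above.
df.selectExpr("json(x)", "json(x.a)").show()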
Result
+----------------+----+
|x               |a   |
+----------------+----+
|[12,null,123]   |12  |
|[null,2222,null]|null|
+----------------+----+

root
 |-- x: struct (nullable = true)
 |    |-- a: string (nullable = true)
 |    |-- b: string (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- a: string (nullable = true)

>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}},{"name":"c","type":"integer","nullable":true,"metadata":{}}]}
>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"col","type":"string","nullable":true,"metadata":{}}]}

+------------------+----+
|_c0               |_c1 |
+------------------+----+
|{"a":"12","c":123}|"12"|
|{"b":"2222"}      |null|
+------------------+----+

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)
Points to Note