Reading ProtoMessages with Spark and Writing Them Out as Parquet

Background

A recent Spark project had the following requirement: read a set of files from an HDFS directory, where each file stores ProtoMessages serialized according to a proto definition, and convert them into Parquet so they can be queried with SQL.

Environment Setup

1. Prepare the proto file: person_entity.proto

syntax = "proto3";
message Person { // define the Person message
    enum Gender { // gender, an enum whose values are Male or Female
        Male = 0;
        Female = 1;
    }
    string name = 1; // the person's name, of type string
    uint32 age = 2; // the person's age, of type uint32
    Gender gender = 3; // the person's gender, of type Gender
    // Parquet does not support a type that nests itself, so this field is commented out
    // repeated Person children = 4; // the person's children, a list of Person
    map<string, string> education_address_map = 5; // map of education stage -> school address
}

For how to generate the Java files from this proto file, see the earlier post 初识Protobuf (Getting Started with Protobuf).

2. Prepare the plugins: project/assembly.sbt

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.27")
libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.0"

3. Prepare the build definition: build.sbt

name := "protobuf_test"

version := "0.1"

scalaVersion := "2.12.12"

libraryDependencies ++= Seq(
  "com.google.protobuf" % "protobuf-java" % "3.5.0",
  "com.google.guava" % "guava" % "16.0.1",
  "org.apache.spark" %% "spark-core" % "3.2.1" % "provided",
  "org.apache.spark" %% "spark-sql" % "3.2.1" % "provided",
  "com.thesamet.scalapb" %% "sparksql-scalapb" % "0.11.0-RC1",
)

PB.targets in Compile := Seq(
  scalapb.gen() -> (sourceManaged in Compile).value
)
// Shade protobuf (and scala-collection-compat) inside the assembly jar so the classes
// do not clash with the older protobuf classes that ship with Spark/Hadoop.
assemblyShadeRules in assembly := Seq(
  ShadeRule.rename("com.google.protobuf.**" -> "shadeproto.@1").inAll,
  ShadeRule.rename("scala.collection.compat.**" -> "shadecompat.@1").inAll
)
// Drop META-INF entries and keep the first copy of any other duplicate file.
assemblyMergeStrategy in assembly := {
  case PathList("META-INF", xs @ _*) => MergeStrategy.discard
  case x => MergeStrategy.first
}

Then run sbt compile in a terminal; the Scala sources generated by ScalaPB will appear under target/scala-2.12/src_managed/main/person_entity. If sbt compile fails with an error such as not found: value PersonEntity, try moving the Java files generated by protoc into the src/main/scala directory.
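
As a point of reference for the two solutions below: protoc generates a Java outer class PersonEntity with a nested Person class (used in Solution 2), while ScalaPB generates a Scala case class person_entity.Person (used with ProtoSQL in Solution 1). The following is only a simplified sketch of roughly what the ScalaPB class looks like; the real generated source also carries a companion object, unknown-field handling and full serialization support.

// Simplified illustration only -- not the actual ScalaPB-generated source.
object person_entity_sketch {
  final case class Person(
      name: String = "",                                   // string name = 1
      age: Int = 0,                                        // uint32 age = 2
      gender: Person.Gender = Person.Gender.Male,          // Gender gender = 3
      educationAddressMap: Map[String, String] = Map.empty // map<string, string> education_address_map = 5
  )
  object Person {
    sealed trait Gender
    object Gender {
      case object Male extends Gender
      case object Female extends Gender
    }
  }
}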

4. Prepare the data files

Here a local directory is used to simulate the HDFS directory. The following code generates 6 files under /tmp/persons.

import java.io.{BufferedOutputStream, File, FileOutputStream}
import PersonEntity.Person
import com.google.common.io.LittleEndianDataOutputStream

object WriteMultiProtobuf {
  def main(args: Array[String]): Unit = {
    val dir = "/tmp/persons"
    new File(dir).mkdirs()
    0 to 5 foreach {
      i => {
        val file = new File(dir + s"/person-${i}")
        val dos = new LittleEndianDataOutputStream(new BufferedOutputStream(new FileOutputStream(file, false)))
        val person = Person.newBuilder()
          .setName("cshen")
          .setAge(18)
          .setGender(Person.Gender.Male)
          .putAllEducationAddressMap(new java.util.HashMap[String, String](){{
            put("undergraduate", "JNU")
            put("postgraduate", "HUST")
          }})
          .build()
        val size = person.getSerializedSize
        println(s"size:${size}")
        dos.writeInt(size) // write the person's serialized size as a 4-byte (little-endian) int
        println(s"person:\n${person}")
        dos.write(person.toByteArray)
        dos.close()
      }
    }
  }
}
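
As an aside, the 4-byte little-endian size prefix is a custom framing. If you control the file format, protobuf-java's built-in delimited framing (a varint length prefix) can do the same job without manual size bookkeeping. A minimal sketch, assuming the protoc-generated PersonEntity.Person class above; the object name and the /tmp/persons-delimited path are made up for illustration.

import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream}
import PersonEntity.Person

object DelimitedFramingSketch {
  def main(args: Array[String]): Unit = {
    val path = "/tmp/persons-delimited" // hypothetical path, just for illustration
    // write: writeDelimitedTo prefixes each message with a varint length
    val out = new BufferedOutputStream(new FileOutputStream(path, false))
    (0 until 3).foreach { i =>
      Person.newBuilder().setName(s"person-$i").setAge(18 + i).build().writeDelimitedTo(out)
    }
    out.close()
    // read: parseDelimitedFrom returns null at end of stream
    val in = new BufferedInputStream(new FileInputStream(path))
    Iterator.continually(Person.parseDelimitedFrom(in)).takeWhile(_ != null).foreach(println)
    in.close()
  }
}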

Solution 1

Use ProtoSQL to build a DataFrame from the ProtoMessages collected to the Driver, then write it out as Parquet.

import java.io.{BufferedInputStream, File, FileInputStream, IOException}

import com.google.common.io.LittleEndianDataInputStream
import org.apache.spark.sql.{SaveMode, SparkSession}
import person_entity.Person // import the Person class generated by ScalaPB for ProtoSQL, not the one generated by protoc
import scalapb.spark.ProtoSQL

import scala.collection.mutable.ListBuffer

object ReadMultiProtobufSolution1 {
  def main(args: Array[String]): Unit = {
    val sparkSession = new SparkSession.Builder().appName("ReadMultiProtobufSolution1")
      .master("local[*]")
      .config("spark.driver.host", "127.0.0.1")
      .getOrCreate()
    val dir = "/tmp/persons"
    val allFiles = new File(dir).list().map(s"${dir}/" + _)
    val personsOnDriver = sparkSession.sparkContext.parallelize(allFiles) // distribute the file names to the executors, which do the actual reading
      .flatMap(filename => { // each file may contain many persons, so flatMap is used to flatten them
        val file = new File(filename)
        val dos = new LittleEndianDataInputStream(new BufferedInputStream(new FileInputStream(file)))
        val resList = new ListBuffer[person_entity.Person]()
        try {
          var size = dos.readInt()
          while (size != -1) {
            val personBytes = new Array[Byte](size)
            dos.readFully(personBytes) // readFully guarantees the whole record is read
//            val person = PersonEntity.Person.parseFrom(personBytes)
            val person = Person.parseFrom(personBytes)
            resList.append(person)
            size = dos.readInt() // throws EOFException (an IOException) at end of file, which ends the loop
          }
        } catch {
          case e: IOException => {
            println(s"get IOException: ${e.getMessage}")
          }
        } finally {
          dos.close()
        }
        resList
      })
      .collect()
    ProtoSQL.createDataFrame(sparkSession, personsOnDriver)
      .write.mode(SaveMode.Overwrite)
      .parquet("/tmp/solution1")
    sparkSession.read.parquet("/tmp/solution1").show(false)
  }
}

This approach has a drawback: all of the ProtoMessages have to be collected to the Driver before the DataFrame can be created and written out as Parquet. That puts a lot of pressure on the Driver, and when the ProtoMessages are large, the network cost of shipping them from the Executors to the Driver is also significant. This approach was therefore abandoned in favor of a second solution.
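
For completeness: sparksql-scalapb also documents Spark encoders for the generated messages via scalapb.spark.Implicits, which would let the Dataset be built directly from the RDD without collecting to the Driver (whether this works depends on the library version's Spark compatibility). A minimal sketch under that assumption; the object name and the /tmp/solution1b output path are hypothetical.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SaveMode, SparkSession}
import person_entity.Person
import scalapb.spark.Implicits._ // assumed to bring an Encoder[Person] into scope

object WriteWithoutCollectSketch {
  def writeParquet(spark: SparkSession, persons: RDD[Person]): Unit = {
    spark.createDataset(persons)      // stays distributed; no collect to the Driver
      .write.mode(SaveMode.Overwrite)
      .parquet("/tmp/solution1b")     // hypothetical output path
  }
}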

Solution 2

Have each executor write the ProtoMessages directly out as Parquet right after reading them.

import java.io.{BufferedInputStream, File, FileInputStream, IOException}
import PersonEntity.Person
import com.google.common.io.LittleEndianDataInputStream
import org.apache.hadoop.fs.Path
import org.apache.parquet.proto.ProtoParquetWriter
import org.apache.spark.sql.SparkSession

object ReadMultiProtobufSolution2 {
  def main(args: Array[String]): Unit = {
    val sparkSession = new SparkSession.Builder().appName("ReadMultiProtobufSolution2")
      .master("local[*]")
      .config("spark.driver.host", "127.0.0.1")
      .getOrCreate()
    val dir = "/tmp/persons"
    val allFiles = new File(dir).list().map(s"${dir}/" + _)
    sparkSession.sparkContext.parallelize(allFiles) // distribute the file names to the executors, which do the actual reading
      .foreach(filename => { // each file may contain many persons; every one of them is written to the Parquet writer
        val file = new File(filename)
        val dos = new LittleEndianDataInputStream(new BufferedInputStream(new FileInputStream(file)))
        val outputParquetPath = s"/tmp/solution2/${filename.substring(filename.lastIndexOf("/") + 1)}.parquet"
        // specify the proto message type and the output path for the Parquet writer
        val writer = new ProtoParquetWriter[Person](
          new Path(outputParquetPath),
          classOf[Person]
        )
        try {
          var size = dos.readInt()
          while (size != -1) {
            val personBytes = new Array[Byte](size)
            dos.readFully(personBytes) // readFully guarantees the whole record is read
            val person = Person.parseFrom(personBytes)
            writer.write(person)
            size = dos.readInt() // throws EOFException (an IOException) at end of file, which ends the loop
          }
        } catch {
          case e: IOException => {
            println(s"get IOException: ${e.getMessage}")
          }
        } finally {
          dos.close()
          writer.close()
        }
      })
    sparkSession.read.parquet("/tmp/solution2").show(false)
  }
}
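
One more practical note: both solutions list the input directory with java.io.File, which matches the local /tmp/persons simulation but not a real HDFS directory. Against actual HDFS, the listing would go through the Hadoop FileSystem API instead; a minimal sketch (the object and method names here are made up for illustration):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object ListHdfsFilesSketch {
  // return the fully qualified paths of all files directly under `dir`
  def listFiles(dir: String): Array[String] = {
    val fs = FileSystem.get(new Configuration())
    fs.listStatus(new Path(dir))
      .filter(_.isFile)
      .map(_.getPath.toString)
  }
}

On the executor side, each file would then be opened with FileSystem.open (which returns an InputStream) rather than FileInputStream, and the rest of the reading logic stays the same.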

Conclusion

In the end the project used Solution 2; the Job with the largest improvement dropped from 10 hours to 21 minutes.
If you run the main function directly from the IDE and get Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/spark/sql/SparkSession$Builder, adjust the Run/Debug Configurations as shown below (the Spark dependencies are marked "provided" in build.sbt, so the run configuration has to put provided-scope dependencies on the classpath):
[Screenshot: Run/Debug Configurations settings]
