A recent Spark project had the following requirement: read a number of files from an HDFS directory, each of which stores ProtoMessages serialized according to a .proto definition, and convert them to Parquet so that they can be queried with SQL. The proto definition is as follows:
syntax = "proto3";
message Person { // the Person message
enum Gender { // gender, an enum whose values are Male or Female
Male = 0;
Female = 1;
}
string name = 1; // name, of type string
uint32 age = 2; // age, of type uint32
Gender gender = 3; // gender, of the Gender enum type
// Parquet does not support a type nested within itself, so this field is commented out
// repeated Person children = 4; // children, a repeated (list) field of Person
map<string, string> education_address_map = 5; // map from education stage (e.g. undergraduate) to school address
}
For how to generate the Java classes from the .proto file, see the earlier post 初识Protobuf. The sbt configuration is shown below.
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.27")
libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.0"
name := "protobuf_test"
version := "0.1"
scalaVersion := "2.12.12"
libraryDependencies ++= Seq(
"com.google.protobuf" % "protobuf-java" % "3.5.0",
"com.google.guava" % "guava" % "16.0.1",
"org.apache.spark" %% "spark-core" % "3.2.1" % "provided",
"org.apache.spark" %% "spark-sql" % "3.2.1" % "provided",
"com.thesamet.scalapb" %% "sparksql-scalapb" % "0.11.0-RC1",
)
PB.targets in Compile := Seq(
scalapb.gen() -> (sourceManaged in Compile).value
)
// shade protobuf-java and scala-collection-compat so they do not clash with the versions already on Spark's classpath
assemblyShadeRules in assembly := Seq(
ShadeRule.rename("com.google.protobuf.**" -> "shadeproto.@1").inAll,
ShadeRule.rename("scala.collection.compat.**" -> "shadecompat.@1").inAll
)
assemblyMergeStrategy in assembly := {
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case x => MergeStrategy.first
}
Then run sbt compile in a terminal; the Scala files generated by ScalaPB will appear under target/scala-2.12/src_managed/main/person_entity. If sbt compile fails with an error such as not found: value PersonEntity, try moving the Java files generated by protoc into the src/main/scala directory.
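Alternatively, instead of copying the protoc output around by hand, sbt-protoc can be asked to generate the Java classes as well. A minimal sketch (not part of the original setup, shown here only as an option):
// build.sbt: generate both the Java and the ScalaPB classes from the same .proto files
PB.targets in Compile := Seq(
PB.gens.java -> (sourceManaged in Compile).value, // PersonEntity.* (Java classes)
scalapb.gen(javaConversions = true) -> (sourceManaged in Compile).value // person_entity.* (ScalaPB classes)
)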
Here a local directory is used to simulate the HDFS directory. The following code writes six files under /tmp/persons.
import java.io.{BufferedOutputStream, File, FileOutputStream}
import PersonEntity.Person // the Java Person class generated by protoc
import com.google.common.io.LittleEndianDataOutputStream
object WriteMultiProtobuf {
def main(args: Array[String]): Unit = {
val dir = "/tmp/persons"
new File(dir).mkdirs()
0 to 5 foreach {
i => {
val file = new File(dir + s"/person-${i}")
val dos = new LittleEndianDataOutputStream(new BufferedOutputStream(new FileOutputStream(file, false)))
val person = Person.newBuilder()
.setName("cshen")
.setAge(18)
.setGender(Person.Gender.Male)
.putEducationAddressMap("undergraduate", "JNU") // education stage -> school address
.putEducationAddressMap("postgraduate", "HUST")
.build()
val size = person.getSerializedSize
println(s"size:${size}")
dos.writeInt(size) // write the serialized size as a 4-byte (little-endian) prefix
println(s"person:\n${person}")
dos.write(person.toByteArray)
dos.close()
}
}
}
}
Solution 1: use ProtoSQL to build a DataFrame from the ProtoMessages collected to the driver, then write it out as Parquet.
import java.io.{BufferedInputStream, File, FileInputStream, IOException}
import com.google.common.io.LittleEndianDataInputStream
import org.apache.spark.sql.{SaveMode, SparkSession}
import person_entity.Person // the ScalaPB-generated Person class, not the Java class generated by protoc
import scalapb.spark.ProtoSQL
import scala.collection.mutable.ListBuffer
object ReadMultiProtobufSolution1 {
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder().appName("ReadMultiProtobufSolution1")
.master("local[*]")
.config("spark.driver.host", "127.0.0.1")
.getOrCreate()
val dir = "/tmp/persons"
val allFiles = new File(dir).list().map(s"${dir}/" + _)
val personsOnDriver = sparkSession.sparkContext.parallelize(allFiles) // distribute the file names so the executors do the reading
.flatMap(filename => { // each file may contain many Person records, so flatten with flatMap
val file = new File(filename)
val dis = new LittleEndianDataInputStream(new BufferedInputStream(new FileInputStream(file)))
val resList = new ListBuffer[person_entity.Person]()
try {
var size = dis.readInt()
while (size != -1) {
val personBytes = new Array[Byte](size)
dis.readFully(personBytes) // read() may return fewer bytes than requested; readFully fills the whole array
// val person = PersonEntity.Person.parseFrom(personBytes)
val person = Person.parseFrom(personBytes)
resList.append(person)
size = dis.readInt()
}
} catch {
case e: IOException =>
// readInt() throws EOFException (an IOException subclass) once the file is exhausted, which ends the loop
println(s"get IOException: ${e.getMessage}")
} finally {
dis.close()
}
resList
})
.collect()
ProtoSQL.createDataFrame(sparkSession, personsOnDriver)
.write.mode(SaveMode.Overwrite)
.parquet("/tmp/solution1")
sparkSession.read.parquet("/tmp/solution1").show(false)
}
}
This solution has a drawback: every ProtoMessage has to be collected to the driver before the DataFrame can be created and written out as Parquet, which puts heavy pressure on the driver. And when the ProtoMessages are large, the network cost of shipping them from the executors to the driver is also considerable. This solution was therefore dropped in favour of a second one.
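(As an aside, here is a sketch, not from the original post, of how the collect could be avoided while staying with ScalaPB: sparksql-scalapb's Implicits provide an Encoder for the generated person_entity.Person class, so a Dataset can be built straight from the RDD. readPersons is a hypothetical helper holding the same file-reading logic as the flatMap above.)
import scalapb.spark.Implicits._ // provides an Encoder for ScalaPB-generated classes
val personsRDD = sparkSession.sparkContext.parallelize(allFiles).flatMap(readPersons) // readPersons: String => Seq[person_entity.Person], hypothetical helper
sparkSession.createDataset(personsRDD)
.write.mode(SaveMode.Overwrite)
.parquet("/tmp/solution1_no_collect")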
Solution 2: after reading the ProtoMessages on the executors, write them out as Parquet directly there.
import java.io.{BufferedInputStream, File, FileInputStream, IOException}
import PersonEntity.Person // the Java Person class generated by protoc; ProtoParquetWriter works with protobuf-java messages
import com.google.common.io.LittleEndianDataInputStream
import org.apache.hadoop.fs.Path
import org.apache.parquet.proto.ProtoParquetWriter
import org.apache.spark.sql.SparkSession
object ReadMultiProtobufSolution2 {
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder().appName("ReadMultiProtobufSolution2")
.master("local[*]")
.config("spark.driver.host", "127.0.0.1")
.getOrCreate()
val dir = "/tmp/persons"
val allFiles = new File(dir).list().map(s"${dir}/" + _)
sparkSession.sparkContext.parallelize(allFiles) // distribute the file names so the executors do the reading
.foreach(filename => { // each executor reads its files and writes them out as Parquet directly; nothing is collected to the driver
val file = new File(filename)
val dis = new LittleEndianDataInputStream(new BufferedInputStream(new FileInputStream(file)))
val outputParquetPath = s"/tmp/solution2/${filename.substring(filename.lastIndexOf("/") + 1)}.parquet"
// specify the proto message type to write and the output path of the Parquet file
val writer = new ProtoParquetWriter[Person](
new Path(outputParquetPath),
classOf[Person]
)
try {
var size = dis.readInt()
while (size != -1) {
val personBytes = new Array[Byte](size)
dis.readFully(personBytes) // read() may return fewer bytes than requested; readFully fills the whole array
val person = Person.parseFrom(personBytes)
writer.write(person)
size = dis.readInt()
}
} catch {
case e: IOException =>
// readInt() throws EOFException (an IOException subclass) once the file is exhausted, which ends the loop
println(s"get IOException: ${e.getMessage}")
} finally {
dis.close()
writer.close()
}
})
sparkSession.read.parquet("/tmp/solution2").show(false)
}
}
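Since the whole point of the conversion is SQL querying, here is a quick sketch of querying the resulting Parquet (the view name and the query itself are illustrative, not from the original post):
val persons = sparkSession.read.parquet("/tmp/solution2")
persons.createOrReplaceTempView("persons")
sparkSession.sql("SELECT name, age, gender FROM persons WHERE age >= 18").show(false)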
In the end the project adopted Solution 2; the job that benefited most went from about 10 hours down to 21 minutes.
If running the main function directly from the IDE fails with Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/spark/sql/SparkSession$Builder, the cause is that the Spark dependencies are marked as provided. In IntelliJ, edit the Run/Debug Configuration and enable the option that puts dependencies with "provided" scope on the runtime classpath.
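If you run the job through sbt instead of the IDE, a commonly used build.sbt workaround (a sketch, not from the original post) is to put the provided dependencies back onto the classpath of the run task:
// build.sbt: let "sbt run" see the "provided" Spark dependencies
run in Compile := Defaults.runTask(fullClasspath in Compile, mainClass in (Compile, run), runner in (Compile, run)).evaluated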