Data volumes keep growing and the data usually lives on HDFS, while most deep learning models are built on frameworks such as TensorFlow and PyTorch. Converting that data on a single machine with Python or Java is slow, so how do we feed HDFS data into TensorFlow at scale? TensorFlow offers a solution: use Spark to generate TFRecord files. The project is called spark-tensorflow-connector, and its GitHub page is https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector. Following the README, you build the jar and add it to your own project as a dependency; if you would rather not build it yourself, you can pull the published artifact directly from https://mvnrepository.com/artifact/org.tensorflow/spark-tensorflow-connector. Under the hood, the project defines a few implicit conversion classes and overrides the output format. The interfaces exposed to the caller are quite simple, with both Scala and Python APIs, and everything ultimately sits on top of protobuf. One has to admire Google's engineering strength and its ability to push its ecosystem. Let's look at how to use it.
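If you go the Maven Central route, the dependency declaration in sbt looks roughly like the line below; the Scala suffix and version number are assumptions, so check mvnrepository.com for the build that matches your Spark and Scala versions.

// build.sbt: a minimal sketch; the artifact suffix and version here are assumptions
libraryDependencies += "org.tensorflow" % "spark-tensorflow-connector_2.11" % "1.15.0"

The first example writes a small test DataFrame out as TFRecords and reads it back: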
package com.xxx.tfrecords

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object TFRecordsExample {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()
    val path = "file/test-output.tfrecord"

    val testRows: Array[Row] = Array(
      new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
      new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
    val schema = StructType(List(
      StructField("id", IntegerType),
      StructField("IntegerCol", IntegerType),
      StructField("LongCol", LongType),
      StructField("FloatCol", FloatType),
      StructField("DoubleCol", DoubleType),
      StructField("VectorCol", ArrayType(DoubleType, true)),
      StructField("StringCol", StringType)))
    val rdd = spark.sparkContext.parallelize(testRows)

    // Save the DataFrame as TFRecords.
    val df: DataFrame = spark.createDataFrame(rdd, schema)
    df.write.format("tfrecords").option("recordType", "Example").save(path)

    // Read TFRecords into a DataFrame.
    // The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
    val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
    importedDf1.show()

    // Read TFRecords into a DataFrame using a custom schema.
    val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
    importedDf2.show()
  }
}
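The example above runs locally on a toy DataFrame. In practice you would read the source data from HDFS, control the number of output part files with repartition, and write the TFRecords back to HDFS so TensorFlow can read them directly. A minimal sketch, assuming a hypothetical Parquet input and HDFS paths (adjust them to your cluster):

package com.xxx.tfrecords

import org.apache.spark.sql.SparkSession

object HdfsToTFRecords {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("hdfs_to_tfrecords").getOrCreate()

    // Hypothetical input: a Parquet table on HDFS with the feature columns already computed.
    val df = spark.read.parquet("hdfs://nameservice1/user/xxx/features")

    // repartition controls how many .tfrecord part files are produced;
    // each row becomes one tf.train.Example that tf.data can consume directly.
    df.repartition(64)
      .write.format("tfrecords")
      .option("recordType", "Example")
      .save("hdfs://nameservice1/user/xxx/tfrecords/train")

    spark.stop()
  }
}

The next example reads BERT-style TFRecord files (input_ids, input_mask, label_ids) back into a DataFrame: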
package com.xxx.tfrecords

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SparkSession}

object TFRecordsBert {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()
    val path = "/Users/shuubiasahi/Desktop/textclass/"

    val schema = StructType(List(
      StructField("input_ids", ArrayType(IntegerType, true)),
      StructField("input_mask", ArrayType(IntegerType, true)),
      StructField("label_ids", ArrayType(IntegerType, true))))

    // Read with schema inference, treating each record as a SequenceExample.
    val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
    importedDf1.show()

    // Read the same files with an explicit schema.
    val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
    importedDf2.show()
  }
}
The show() output looks like this:

+--------------------+--------------------+---------+
|           input_ids|          input_mask|label_ids|
+--------------------+--------------------+---------+
|[101, 4281, 3566,...|[1, 1, 1, 1, 1, 1...|     [25]|
|[101, 3433, 5866,...|[1, 1, 1, 1, 0, 0...|     [40]|
|[101, 6631, 5277,...|[1, 1, 1, 1, 1, 1...|      [5]|
+--------------------+--------------------+---------+
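Going the other way, producing BERT-style TFRecords from Spark, looks roughly like the sketch below. The tokenization step is a placeholder (a real pipeline would map text to WordPiece ids from the BERT vocab.txt), so treat it only as an illustration of building the input_ids / input_mask / label_ids columns and writing them out:

package com.xxx.tfrecords

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object BertTFRecordsWriter {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("bert_tfrecords_writer").getOrCreate()
    val maxSeqLen = 8 // toy length; real BERT pipelines usually use 128 or 512

    // Placeholder "encoder": pad/truncate token ids and build the attention mask.
    def encode(ids: Seq[Int]): (Seq[Int], Seq[Int]) = {
      val truncated = ids.take(maxSeqLen)
      (truncated.padTo(maxSeqLen, 0), truncated.map(_ => 1).padTo(maxSeqLen, 0))
    }

    // Hypothetical already-tokenized samples: (token ids, label id).
    val samples = Seq((Seq(101, 4281, 3566, 102), 25), (Seq(101, 3433, 5866, 102), 40))
    val rows = samples.map { case (ids, label) =>
      val (inputIds, inputMask) = encode(ids)
      Row(inputIds, inputMask, Seq(label))
    }

    val schema = StructType(List(
      StructField("input_ids", ArrayType(IntegerType, true)),
      StructField("input_mask", ArrayType(IntegerType, true)),
      StructField("label_ids", ArrayType(IntegerType, true))))

    // One tf.train.Example per row; point the save path at HDFS for a real job.
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
    df.write.format("tfrecords").option("recordType", "Example").save("file/bert-train.tfrecord")
    spark.stop()
  }
}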