利用spark生成tfrecord文件

目前数据越来越多,数据一般存储在hdfs上,但是目前许多深度学习算法是基于TensorFlow、pytorch等框架实现,使用单机python、java做数据转换都比较慢,怎么大规模把hdfs数据直接喂到TensorFlow中?在这里TensorFlow提供了一种解决方案:利用spark生成tfrecord文件,项目名称叫spark-tensorflow-connector,GitHub主页在 https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector 。按照readme编译jar包,放在自己的项目里面做依赖即可使用;如果实在不想自己编译jar包,也可以在 https://mvnrepository.com/artifact/org.tensorflow/spark-tensorflow-connector 上直接添加依赖下载。主要原理是在这个项目里面写了一些隐式转换类,重写了输出的格式,对上层暴露的接口都比较简单,提供了scala、python的接口,实际背后全部是依赖于proto,不得不佩服google的技术的强大以及推广能力,下面看下怎么使用:

 

官方例子:

package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf

import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

/**
 * Round-trip demo: builds a two-row DataFrame, writes it with the "tfrecords"
 * data source, then reads it back twice (once with an inferred schema, once
 * with an explicit one) and prints both results.
 */
object TFRecordsExample {

  /**
   * Entry point.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Quiet down Spark / Jetty logging so the show() output is readable.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    try {
      val path = "file/test-output.tfrecord"

      // Column order and types here must line up with `schema` below.
      val schema = StructType(List(
        StructField("id", IntegerType),
        StructField("IntegerCol", IntegerType),
        StructField("LongCol", LongType),
        StructField("FloatCol", FloatType),
        StructField("DoubleCol", DoubleType),
        StructField("VectorCol", ArrayType(DoubleType, true)),
        StructField("StringCol", StringType)))

      val testRows: Array[Row] = Array(
        new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
        new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))

      val rdd = spark.sparkContext.parallelize(testRows)

      // Save DataFrame as TFRecords.
      // "overwrite" keeps the demo re-runnable: with the default save mode a second
      // run fails because the output directory already exists.
      val df: DataFrame = spark.createDataFrame(rdd, schema)
      df.write
        .format("tfrecords")
        .option("recordType", "Example")
        .mode("overwrite")
        .save(path)

      // Read TFRecords into DataFrame.
      // The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
      val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
      importedDf1.show()

      // Read TFRecords into DataFrame using custom schema.
      // No recordType option is passed here, so the data source's default applies.
      val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
      importedDf2.show()
    } finally {
      // Always release the local Spark context, even if a read/write step throws.
      spark.stop()
    }
  }

}

 

读取bert模型训练的数据测试:

package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf

import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
/**
 * Reads TFRecord files (BERT training data in the accompanying article) into
 * DataFrames and prints them: first with an inferred schema, then with an
 * explicit schema of three integer-array columns.
 */
object TFRcordsBert {

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the directory to read from.
   *             When absent, the original hard-coded location is used.
   */
  def main(args: Array[String]): Unit = {
    // Quiet down Spark / Jetty logging so the show() output is readable.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    try {
      // Input location; falls back to the original path when no argument is supplied.
      val path = args.headOption.getOrElse("/Users/shuubiasahi/Desktop/textclass/")

      val schema = StructType(List(
        StructField("input_ids", ArrayType(IntegerType, true)),
        StructField("input_mask", ArrayType(IntegerType, true)),
        StructField("label_ids", ArrayType(IntegerType, true))))

      // Schema inferred from the records.
      // NOTE(review): this read asks for "SequenceExample" while the read below passes
      // no recordType at all; confirm which record type the files were written with.
      val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
      importedDf1.show()

      // Same files, but restricted to the explicitly declared columns.
      val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
      importedDf2.show()
    } finally {
      // Always release the local Spark context, even if a read fails.
      spark.stop()
    }
  }
}

+--------------------+--------------------+---------+
|           input_ids|          input_mask|label_ids|
+--------------------+--------------------+---------+
|[101, 4281, 3566,...|[1, 1, 1, 1, 1, 1...|     [25]|
|[101, 3433, 5866,...|[1, 1, 1, 1, 0, 0...|     [40]|
|[101, 6631, 5277,...|[1, 1, 1, 1, 1, 1...|      [5]|

你可能感兴趣的:(机器学习,apache,spark)