Reading and writing TensorFlow TFRecords with Spark

Dependency

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>spark-tensorflow-connector_2.11</artifactId>
            <version>1.13.1</version>
            <scope>compile</scope>
        </dependency>

The spark-tensorflow-connector JAR is also available on the network drive.
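For sbt builds, the same coordinates can probably be declared as below. This is only a sketch: it assumes the artifact resolves from a repository your build already uses, and that the project is on Scala 2.11 to match the `_2.11` suffix. If it does not resolve, dropping the JAR into the project's `lib/` directory as an unmanaged dependency is the usual fallback.

```scala
// build.sbt (sketch; assumes a Scala 2.11 project and a resolvable artifact)
libraryDependencies += "org.tensorflow" % "spark-tensorflow-connector_2.11" % "1.13.1"
```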

Reading and writing

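  import org.apache.spark.sql.{DataFrame, SparkSession}

  // Load TFRecord files containing tf.train.Example records into a DataFrame.
  def readTfrecord(spark: SparkSession, path: String): DataFrame = {
    spark.read.format("tfrecords").option("recordType", "Example").load(path)
  }

  // Write df as TFRecord files under pathRes, overwriting any existing output.
  // numPart <= 0 keeps the current partitioning; otherwise df is repartitioned to numPart partitions first.
  def saveAsTfrecord(df: DataFrame, numPart: Int, pathRes: String): Unit = {
    val out = if (numPart <= 0) df else df.repartition(numPart)
    out.write.format("tfrecords").option("recordType", "Example").mode("overwrite").save(pathRes)
  }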
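
To check that the connector is wired up, a small round trip can help: write a tiny DataFrame as TFRecords and read it back. The sketch below makes several assumptions that are not in the original: the object name `TfrecordRoundTrip`, the column names, the local master, and the output path `/tmp/tfrecord-demo` are all made up for illustration. It passes an explicit schema on read so that column types do not depend on inference.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{FloatType, LongType, StringType, StructField, StructType}

object TfrecordRoundTrip {
  // Hypothetical schema, used only for this demo.
  val schema: StructType = StructType(Seq(
    StructField("id", LongType),
    StructField("score", FloatType),
    StructField("label", StringType)
  ))

  // Writes two rows as TFRecord Example files under `path`, then reads them back.
  def roundTrip(spark: SparkSession, path: String): DataFrame = {
    val rows = Seq(Row(1L, 0.5f, "a"), Row(2L, 1.5f, "b"))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    df.repartition(1)
      .write.format("tfrecords")
      .option("recordType", "Example")
      .mode("overwrite")
      .save(path)

    // An explicit schema skips inference when loading.
    spark.read.format("tfrecords")
      .option("recordType", "Example")
      .schema(schema)
      .load(path)
  }

  def main(args: Array[String]): Unit = {
    // Assumes the connector JAR is on the classpath and a local run is acceptable.
    val spark = SparkSession.builder()
      .appName("tfrecord-demo")
      .master("local[*]")
      .getOrCreate()
    try {
      roundTrip(spark, "/tmp/tfrecord-demo").show()
    } finally {
      spark.stop()
    }
  }
}
```

Row order after the read is not guaranteed, so sort by `id` before comparing against the input.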
