更新:
2018.3.26 对于每个例子添加了详细的解释,方便理解.
做过kaggle竞赛的应该很熟悉.csv文件了,.csv文件非常方便,但是通常读取的时候,是一次性读取到内存里面的.要是内存小的话,就要想其他的办法了,那就变得很麻烦了.
或者有时候,从硬盘上面直接读取图片啊什么的,因为图片的文件格式,存放位置各种各样等等一些因素,要是想在训练阶段直接这么使用的话,就更加麻烦了.所以,对于数据进行统一的管理是很有必要的.TFRecord就是对于输入数据做统一管理的格式.加上一些多线程的处理方式,使得在训练期间对于数据管理把控的效率和舒适度都好于暴力的方法.
小的任务什么方法差别不大,但是对于大的任务,使用统一格式管理的好处就非常显著了.因此,TFRecord的使用方法很有必要熟悉.
这节并不准备将TFRcord文件的读取,只讲怎么保存为TFRecord文件,读取还涉及到其他的操作,所以之后会和其他的操作一起讲.
本文的顺序是先讲保存TFRecord文件的时候常见的API,然后再举例子在实际中怎么使用这些API.
把记录写入到TFRecords文件的类.
__init__
(path,options=None)
作用:创建一个
TFRecordWriter
对象,这个对象就负责写记录到指定的文件中去了.
参数:
path: TFRecords 文件路径
options: (可选) TFRecordOptions对象
close()
作用:关闭对象.
write(record)
作用:把字符串形式的记录写到文件中去.
参数:
record: 字符串,待写入的记录
这个类是非常重要的,TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.在这里,不会非常详细的讲这个类,但是会给出常见的使用方法和一些重要函数的解释.其他的细节可以参考文档.
class tf.train.Example
features Magic attribute generated for “features” proto field.
__init__
(**kwargs)
这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个
tf.train.Features
对象进去.
SerializeToString()
作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.
class tf.train.Features
feature
__init__
(**kwargs)
作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.
class tf.train.Feature
bytes_list
float_list
int64_list
函数:
__init__
(**kwargs)
作用:构造一个Feature对象,一般使用的时候,传入
tf.train.Int64List
,tf.train.BytesList
,tf.train.FloatList
对象.
使用的时候,一般传入一个具体的值,比如学习任务中的标签就可以传进value=tf.train.Int64List
,而图片就可以先转为字符串的格式之后,传入value=tf.train.BytesList
中.
说明:
以上的函数的API都可以对照着例子的代码来熟悉,看在例子中是怎么使用的这些对象.
这里直接写两个非常常见的例子(有多常见你看一下就知道了),来体会一下TFRecord可能的用法,在这里可以并不用知道程序运行每一行的具体含义,但是要大概知道是怎么回事.更加详细的对这两个例子的讲解会在后面的内容.这里的目的就是先感受一下.
在这个例子中使用的.csv文件就是kaggle竞赛里面MNIST手写体识别的train.csv文件,可以在官网上面下载实验一下.
import tensorflow as tf
import numpy as np
import pandas as pd
#---------------------loading data from .csv file------------------------------#
#load .csv file
train_frame=pd.read_csv(filepath_or_buffer="../Mnist/train.csv")
#print(train_frame.head())
train_labels_frame=train_frame.pop(item="label")
#print(train_labels_frame.shape)
train_values=train_frame.values
train_size=train_values.shape[0]
train_labels_values=train_labels_frame.values
#print(train_values[0].shape)
#print(train_values[0].dtype)
#print(train_labels_values[0])
#------------------------------create TFRecord file----------------------------#
writer=tf.python_io.TFRecordWriter(path="train.tfrecords")
for i in range(train_size):
image_raw=train_values[i].tostring()
example=tf.train.Example(
features=tf.train.Features(
feature={
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels_values[i]]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
这里使用的例子还是来自于kaggle,是在kaggle中的CIFAR-10识别比赛中的数据集.
CIFAR-10 - Object Recognition in Images
数据集主要是两个压缩包,train数据集和test数据集,解压之后就是两个文件夹,里面放了一大推图片.
要是不用TFRecord的话,直接处理是比较低效的.那么来看看用TFRecord怎么处理.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd
import os
#get the amount of files in folder
def sizeOfFolder(folder_path):
fileNameList = os.listdir(path=folder_path)
size = 0
for fileName in fileNameList:
if (os.path.isfile(path=os.path.join(folder_path, fileName))):
size += 1
return size
#path(path of folder)
#if isTrain=True,labels can't be None
def pics_to_TFRecord(folder_path,labels=None,isTrain=False):
size=sizeOfFolder(folder_path=folder_path)
#train set
if isTrain:
if labels is None:
print("labels can't be None!!!")
return None
if labels.shape[0]!=size:
print("something wrong with shape!!!")
return None
writer=tf.python_io.TFRecordWriter("../data/TFRecords/train.tfrecords")
for i in range(1,size+1):
print("----------processing the ",i,"\'th image----------")
filename=folder_path+str(i)+".png"
img=mpimg.imread(fname=filename)
width=img.shape[0]
print(width)
#trans to string
img_raw=img.tostring()
example=tf.train.Example(
features=tf.train.Features(
feature={
"img_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i-1]])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
#test set
else:
writer = tf.python_io.TFRecordWriter("../data/TFRecords/test.tfrecords")
for i in range(1, size + 1):
print("----------processing the ", i, "\'th image----------")
filename = folder_path + str(i) + ".png"
img = mpimg.imread(fname=filename)
width = img.shape[0]
print(width)
# trans to string
img_raw = img.tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
"width": tf.train.Feature(int64_list=tf.train.Int64List(value=[width]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
train_labels_frame=pd.read_csv("../data/trainLabels.csv")
train_labels_frame_dummy=pd.get_dummies(data=train_labels_frame)
#print(train_labels_frame_dummy)
train_labels_frame_dummy.pop(item="id")
#print(train_labels_frame_dummy)
train_labels_values_dummy=train_labels_frame_dummy.values
#print(train_labels_values_dummy)
train_labels_values=np.argmax(train_labels_values_dummy,axis=1)
#print(train_labels_values)
#write train record
pics_to_TFRecord(folder_path="../data/train/",labels=train_labels_values,isTrain=True)
#write test record
pics_to_TFRecord(folder_path="../data/test/")
上面的代码运行之后,就会在设置的文件夹下面得到两个.tfrecords文件.
其中训练集的有500多M,测试集达到了惊人的3G多.但是为了后面训练的方便,这是值得的.
通过前面两个方法,我们知道可以把你想要的文件或者记录通过或多或少的方法转为TFRecord格式.
那么数据量很大的时候,你会发现,单个TFRecord文件是非常非常大的,这对于硬盘是不小的负担,所以,可以通过存储多个TFRecord文件来解决问题.
import tensorflow as tf
num_files=3
num_instance=100
for i in range(num_files):
print("write ",i," file")
fileName=("test.tfrecords-%.5d-of-%.5d" % (i,num_files))
writer=tf.python_io.TFRecordWriter(path=fileName)
for j in range(num_instance):
print("write ",j," record")
example=tf.train.Example(
features=tf.train.Features(
feature={
"i":tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
"j":tf.train.Feature(int64_list=tf.train.Int64List(value=[j]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()