Reference article: How to write into and read from a TFRecords file in TensorFlow
Dataset: Dogs vs. Cats
TensorFlow provides a unified format for storing data: TFRecord. The data in a TFRecords file is stored as tf.train.Example protocol buffers. The following code gives the definition of tf.train.Example.
message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
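To see how this nesting maps onto the Python API, here is a minimal sketch (with a made-up 'label' key and value, purely for illustration) that builds a single Example by hand:

import tensorflow as tf

# Feature -> Features -> Example, mirroring the proto definition above
label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))
example = tf.train.Example(features=tf.train.Features(feature={'label': label_feature}))
print(example)  # prints the protocol buffer in text form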
We now use the Dogs vs. Cats dataset to demonstrate how to generate and read a TFRecords file.
First, we need to list the images and their labels. We give cats label=0 and dogs label=1. The following code lists all the images, assigns each the appropriate label, and shuffles the data. It also splits the dataset into a training set (60%), a validation set (20%), and a test set (20%).
from random import shuffle
import glob

shuffle_data = True  # shuffle the addresses before saving
cat_dog_train_path = 'Cat vs Dog/train/*.jpg'

# read addresses and labels from the 'train' folder
addrs = glob.glob(cat_dog_train_path)
labels = [0 if 'cat' in addr else 1 for addr in addrs]  # 0 = Cat, 1 = Dog

# shuffle the data
if shuffle_data:
    c = list(zip(addrs, labels))
    shuffle(c)
    addrs, labels = zip(*c)

# Divide the data into 60% train, 20% validation, and 20% test
train_addrs = addrs[0:int(0.6 * len(addrs))]
train_labels = labels[0:int(0.6 * len(labels))]
val_addrs = addrs[int(0.6 * len(addrs)):int(0.8 * len(addrs))]
val_labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
test_addrs = addrs[int(0.8 * len(addrs)):]
test_labels = labels[int(0.8 * len(labels)):]
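As a quick, optional sanity check (not part of the original pipeline), you can print the size of each split to confirm the 60/20/20 partition:

# Each split should hold as many labels as addresses
print('train: {}, val: {}, test: {}'.format(len(train_addrs), len(val_addrs), len(test_addrs)))
assert len(train_addrs) == len(train_labels)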
Next, we need to read each image and convert it into the format in which we want to store the data in the TFRecords file (float32 in this example). The following function reads an image, resizes it, and returns it in the proper format.
import cv2
import numpy as np

def load_image(addr):
    # read an image and resize to (224, 224)
    # cv2 loads images as BGR, convert to RGB
    img = cv2.imread(addr)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    return img
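For example, assuming the split above has been run, loading the first training image yields a 224x224x3 float32 array:

img = load_image(train_addrs[0])
print(img.shape, img.dtype)  # (224, 224, 3) float32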
Before the data is saved to the TFRecords file, it has to be wrapped in a protocol buffer called Example. We then serialize the protocol buffer to a string and write it to the TFRecords file. The Example protocol buffer contains Features, and Feature is a protocol for describing the data; it comes in three types: bytes, float, and int64. In summary, saving your data takes the following steps:
1. Open a TFRecords file with tf.python_io.TFRecordWriter
2. Convert the data into a feature of the appropriate type with tf.train.Int64List, tf.train.BytesList, or tf.train.FloatList
3. Create a feature with tf.train.Feature and pass the converted data to it
4. Create an Example protocol buffer with tf.train.Example and pass the feature to it
5. Serialize the Example to a string with example.SerializeToString()
6. Write the serialized example to the file with writer.write
In this example we will use the following two functions to create features:
import tensorflow as tf

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
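Step 2 above also mentions tf.train.FloatList; the dog/cat example only needs the two helpers above, but for float-valued data an analogous helper (not used in this tutorial) would look like this:

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))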
Now we write the data to the TFRecords file:
import sys

train_filename = 'train.tfrecords'  # address to save the TFRecords file
# open the TFRecords file
writer = tf.python_io.TFRecordWriter(train_filename)
for i in range(len(train_addrs)):
    # print how many images are saved every 1000 images
    if not i % 1000:
        print('Train data: {}/{}'.format(i, len(train_addrs)))
        sys.stdout.flush()
    # Load the image
    img = load_image(train_addrs[i])
    label = train_labels[i]
    # Create a feature
    feature = {'train/label': _int64_feature(label),
               'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write on the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()
Similarly, we generate the TFRecords files for validation and test:
# open the TFRecords file
val_filename = 'val.tfrecords'  # address to save the TFRecords file
writer = tf.python_io.TFRecordWriter(val_filename)
for i in range(len(val_addrs)):
    # print how many images are saved every 1000 images
    if not i % 1000:
        print('Val data: {}/{}'.format(i, len(val_addrs)))
        sys.stdout.flush()
    # Load the image
    img = load_image(val_addrs[i])
    label = val_labels[i]
    # Create a feature
    feature = {'val/label': _int64_feature(label),
               'val/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write on the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()

# open the TFRecords file
test_filename = 'test.tfrecords'  # address to save the TFRecords file
writer = tf.python_io.TFRecordWriter(test_filename)
for i in range(len(test_addrs)):
    # print how many images are saved every 1000 images
    if not i % 1000:
        print('Test data: {}/{}'.format(i, len(test_addrs)))
        sys.stdout.flush()
    # Load the image
    img = load_image(test_addrs[i])
    label = test_labels[i]
    # Create a feature
    feature = {'test/label': _int64_feature(label),
               'test/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write on the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()
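To verify that a file was written correctly, one option (an extra check, not part of the original article) is to count the serialized records with tf.python_io.tf_record_iterator:

# Iterate over the serialized records and count them
num_records = sum(1 for _ in tf.python_io.tf_record_iterator('train.tfrecords'))
print('train.tfrecords contains {} records'.format(num_records))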
Reference article on TensorFlow's file-reading mechanism: 十图详解tensorflow数据读取机制
To read a TFRecords file, take the following steps:
1. Create a list of filenames: in this example we only have one file, data_path = 'train.tfrecords', so our list is simply [data_path].
2. Create a filename queue: use tf.train.string_input_producer to create a FIFO queue. It takes the list of filenames, and the system automatically turns it into a queue of filenames. It also has two important parameters: num_epochs, which specifies the number of epochs, and shuffle, which specifies whether to shuffle the order.
3. Define a reader: for a TFRecords file we need to define a TFRecordReader, reader = tf.TFRecordReader(). The reader then returns the next record via reader.read(filename_queue).
4. Define a decoder: the record returned by the reader has to be parsed by a decoder. The decoder for a TFRecords file is tf.parse_single_example. It takes a serialized Example and a dict (keys are the feature names, values are FixedLenFeature or VarLenFeature) and returns a dict (keys are the feature names, values are Tensors): features = tf.parse_single_example(serialized_example, features=feature).
5. Convert the data from strings back to numbers: tf.decode_raw(bytes, out_type) takes a string Tensor and converts it to type out_type. For labels, which were never converted to strings, we simply use tf.cast(x, dtype).
6. Reshape the data back to its original shape: image = tf.reshape(image, [224, 224, 3]).
7. Preprocessing: if you want to preprocess the data, do it now.
8. Batching: another queue is used to build batches from the examples, e.g. tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10), where capacity is the maximum size of the queue, min_after_dequeue is the minimum size of the queue after dequeuing, and num_threads is the number of threads enqueuing examples. Using more threads speeds up reading.
The code to read the TFRecords file:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

data_path = 'train.tfrecords'  # address of the TFRecords file

with tf.Session() as sess:
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}
    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['train/image'], tf.float32)
    # Cast label data into int32
    label = tf.cast(features['train/label'], tf.int32)
    # Reshape image data into the original shape
    image = tf.reshape(image, [224, 224, 3])
    # Any preprocessing here ...
    # Create batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch([image, label], batch_size=10,
                                            capacity=30, num_threads=1,
                                            min_after_dequeue=10)
Initializing global and local variables
Functions such as tf.train.string_input_producer and tf.train.shuffle_batch add tf.train.QueueRunner objects to the graph. Each of these objects holds a list of enqueue ops. After we create the filename queue with tf.train.string_input_producer, the whole system is still in a "stalled" state; that is, the filenames have not actually been added to the queue yet. Only after calling tf.train.start_queue_runners are the threads that fill the queue started, and the system is no longer "stalled". From then on the compute nodes can fetch data and run, and the whole program gets going; that is what tf.train.start_queue_runners is for. To manage these threads, we need a tf.train.Coordinator to stop them at the appropriate time.
The code for this part (it continues inside the with-block above):
    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for batch_index in range(5):
        img, lbl = sess.run([images, labels])
        img = img.astype(np.uint8)
        for j in range(6):
            plt.subplot(2, 3, j + 1)
            plt.imshow(img[j, ...])
            plt.title('cat' if lbl[j] == 0 else 'dog')
        plt.show()
    # Stop the threads
    coord.request_stop()
    # Wait for threads to stop
    coord.join(threads)
    sess.close()
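As a side note, the queue-based pipeline above is the classic TF 1.x input mechanism. The same records can also be read with the tf.data API; the following is a rough sketch under the assumption of TensorFlow 1.4 or later, not part of the original article:

import tensorflow as tf

def _parse(serialized_example):
    # Same feature schema as in the queue-based version above
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}
    parsed = tf.parse_single_example(serialized_example, features=feature)
    image = tf.reshape(tf.decode_raw(parsed['train/image'], tf.float32), [224, 224, 3])
    label = tf.cast(parsed['train/label'], tf.int32)
    return image, label

dataset = tf.data.TFRecordDataset('train.tfrecords')
dataset = dataset.map(_parse).shuffle(buffer_size=100).batch(10)
images, labels = dataset.make_one_shot_iterator().get_next()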