TFRecord
是TensorFlow
官方推荐使用的数据格式化存储工具,它不仅规范了数据的读写方式,还大大地提高了IO效率。
TFRecord
内部使用了Protocol Buffer
二进制数据编码方案,只要生成一次TFRecord
,之后的数据读取和加工处理的效率都会得到提高。
而且,使用TFRecord
可以直接作为Cloud ML Engine
的输入数据。
一般来说,我们使用TensorFlow
进行数据读取的方式有以下4种:
Python
代码读取一部分数据,然后使用feed_dict
输入到计算图Threading
和Queues
从TFRecord
中分批次读取数据Dataset API
(1)方案对于数据量不大的场景来说是足够简单而高效的,但是随着数据量的增长,势必会对有限的内存空间带来极大的压力,还有长时间的数据预加载,甚至导致我们十分熟悉的OutOfMemoryError
;
(2)方案可以一定程度上缓解了方案(1)的内存压力问题,但是由于在单线程环境下我们的IO操作一般都是同步阻塞的,势必会在一定程度上导致学习时间的增加,尤其是相同的数据需要重复多次读取的情况下;
而方案(3)和方案(4)都利用了我们的TFRecord
,由于使用了多线程使得IO操作不再阻塞我们的模型训练,同时为了实现线程间的数据传输引入了Queues
。
下面,我们以Fashion MNIST
数据集为例,介绍生成TFRecrd
的方法。
所谓的Fashion MNIST
数据集,其实就是大小为28*28
的共10
个类别的服装图像:
下面我们把数据集下载并保存到data/fashion
目录下:
$ mkdir -p data/fashin
$ cd data/fashion
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
$ cd ../..
然后,我们在TensorFlow
使用和MNIST
数据集相同的代码进行数据读取:
from tensorflow.examples.tutorials.mnist import input_data
fashion_mnist = input_data.read_data_sets('data/fashion')
使用TFRecord
时,一般以tf.train.Example
和tf.train.SequenceExample
作为基本单位来进行数据读取。
tf.train.Example
一般用于数值、图像等有固定大小的数据,同时使用tf.train.Feature
指定每个记录各特征的名称和数据类型,用法如下:
tf.train.Example(features=tf.train.Features(feature={
'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
'width' : tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
'depth' : tf.train.Feature(int64_list=tf.train.Int64List(value=[depth])),
'image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}))
tf.train.SequenceExample
一般用于文本、时间序列等没有固定长度大小的数据,用法如下:
example = tf.train.SequenceExample()
# 通过context来指定数据量的大小
example.context.feature["length"].int64_list.value.append(len(data))
# 通过feature_lists来加载数据
words_list = example.feature_lists.feature_list["words"]
for word in words:
words_list.feature.add().int64_list.value.append(word_id(word))
接下来,让我们把原始的Fashion MNIST
数据集转化为TFRecord
并保存下来:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def make_example(image, label):
return tf.train.Example(features=tf.train.Features(feature={
'image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
'label' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))
}))
def write_tfrecord(images, labels, filename):
writer = tf.python_io.TFRecordWriter(filename)
for image, label in zip(images, labels):
labels = labels.astype(np.float32)
ex = make_example(image.tobytes(), label.tobytes())
writer.write(ex.SerializeToString())
writer.close()
def main():
fashion_mnist = input_data.read_data_sets('data/fashion', one_hot=True)
train_images = fashion_mnist.train.images
train_labels = fashion_mnist.train.labels
test_images = fashion_mnist.test.images
test_labels = fashion_mnist.test.labels
write_tfrecord(train_images, train_labels, 'fashion_mnist_train.tfrecord')
write_tfrecord(test_images, test_labels, 'fashion_mnist_test.tfrecord')
if __name__ == '__main__':
main()
执行了上面的代码后,会在当前工作目录下生成两个TFRecord
数据文件——fashion_mnist_train.tfrecord
和fashion_mnist_test.tfrecord
。
如果我们想确认下刚才生成的TFRecord
是否合乎我们的预期,tf.train.Example.FromString
应该是不二之选了。
In [1]: import tensorflow as tf
In [2]: example = next(tf.python_io.tf_record_iterator("fashion_mnist_train.tfrecord"))
In [3]: tf.train.Example.FromString(example)
Out[3]:
features {
feature {
feature {
key: "image"
value {
bytes_list {
value: "\000...\000"
}
}
}
feature {
key: "label"
value {
bytes_list {
value: "\000...\000"
}
}
}
}
由此可知,features
包含了image
、label
、height
、width
等特征。
为了完成这项任务,推荐使用tf.parse_single_example
:
def read_tfrecord(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image'], tf.float32)
label = tf.decode_raw(features['label'], tf.float64)
image = tf.reshape(image, [28, 28, 1])
label = tf.reshape(label, [10])
image, label = tf.train.batch([image, label],
batch_size=16,
capacity=500)
return image, label
下面让我们把TFRecord
使用到真实的模型训练场景中,虽然这次的Fashion MNIST
数据量并不算大,完全可以一次性全部加载到内存中,但我们的TFRecord
一样有用武之地,就是实现异步IO。
import numpy as np
import tensorflow as tf
import tfrecord_io
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import slim
def model(image, label):
net = slim.conv2d(image, 48, [5,5], scope='conv1')
net = slim.max_pool2d(net, [2,2], scope='pool1')
net = slim.conv2d(net, 96, [5,5], scope='conv2')
net = slim.max_pool2d(net, [2,2], scope='pool2')
net = slim.flatten(net, scope='flatten')
net = slim.fully_connected(net, 512, scope='fully_connected1')
logits = slim.fully_connected(net, 10,
activation_fn=None, scope='fully_connected2')
prob = slim.softmax(logits)
loss = slim.losses.softmax_cross_entropy(logits, label)
train_op = slim.optimize_loss(loss, slim.get_global_step(),
learning_rate=0.001,
optimizer='Adam')
return train_op
def main():
train_images, train_labels = tfrecord_io.read_tfrecord('fashion_mnist_train.tfrecord')
train_op = model(train_images, train_labels)
step = 0
with tf.Session() as sess:
init_op = tf.group(
tf.local_variables_initializer(),
tf.global_variables_initializer())
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
while step < 3000:
sess.run([train_op])
if step % 100 == 0:
print('step: {}'.format(step))
step += 1
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
main()