深入浅出的TensorFlow数据格式化存储工具TFRecord用法教程

TFRecordTensorFlow官方推荐使用的数据格式化存储工具,它不仅规范了数据的读写方式,还大大地提高了IO效率。

1.使用TFRecord的理由

TFRecord内部使用了Protocol Buffer二进制数据编码方案,只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。

而且,使用TFRecord可以直接作为Cloud ML Engine的输入数据。

一般来说,我们使用TensorFlow进行数据读取的方式有以下4种:

  1. 预先把所有数据加载进内存
  2. 在每轮训练中使用原生Python代码读取一部分数据,然后使用feed_dict输入到计算图
  3. 利用ThreadingQueuesTFRecord中分批次读取数据
  4. 使用Dataset API

(1)方案对于数据量不大的场景来说是足够简单而高效的,但是随着数据量的增长,势必会对有限的内存空间带来极大的压力,还有长时间的数据预加载,甚至导致我们十分熟悉的OutOfMemoryError

(2)方案可以一定程度上缓解了方案(1)的内存压力问题,但是由于在单线程环境下我们的IO操作一般都是同步阻塞的,势必会在一定程度上导致学习时间的增加,尤其是相同的数据需要重复多次读取的情况下;

而方案(3)和方案(4)都利用了我们的TFRecord,由于使用了多线程使得IO操作不再阻塞我们的模型训练,同时为了实现线程间的数据传输引入了Queues

2.准备数据

下面,我们以Fashion MNIST数据集为例,介绍生成TFRecrd的方法。

所谓的Fashion MNIST数据集,其实就是大小为28*28的共10个类别的服装图像:

fashion-mnist

下面我们把数据集下载并保存到data/fashion目录下:

$ mkdir -p data/fashin
$ cd data/fashion
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
$ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
$ cd ../..

然后,我们在TensorFlow使用和MNIST数据集相同的代码进行数据读取:

from tensorflow.examples.tutorials.mnist import input_data

fashion_mnist = input_data.read_data_sets('data/fashion')

3.Example记录和SequenceExample记录

使用TFRecord时,一般以tf.train.Exampletf.train.SequenceExample作为基本单位来进行数据读取。

tf.train.Example一般用于数值、图像等有固定大小的数据,同时使用tf.train.Feature指定每个记录各特征的名称和数据类型,用法如下:

tf.train.Example(features=tf.train.Features(feature={
    'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
    'width' : tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
    'depth' : tf.train.Feature(int64_list=tf.train.Int64List(value=[depth])),
    'image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}))

tf.train.SequenceExample一般用于文本、时间序列等没有固定长度大小的数据,用法如下:

example = tf.train.SequenceExample()
# 通过context来指定数据量的大小
example.context.feature["length"].int64_list.value.append(len(data))

# 通过feature_lists来加载数据
words_list = example.feature_lists.feature_list["words"]
for word in words:
    words_list.feature.add().int64_list.value.append(word_id(word))

4.生成TFRecord

接下来,让我们把原始的Fashion MNIST数据集转化为TFRecord并保存下来:

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

def make_example(image, label):
    return tf.train.Example(features=tf.train.Features(feature={
        'image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
        'label' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))
    }))

def write_tfrecord(images, labels, filename):
    writer = tf.python_io.TFRecordWriter(filename)
    for image, label in zip(images, labels):
        labels = labels.astype(np.float32)
        ex = make_example(image.tobytes(), label.tobytes())
        writer.write(ex.SerializeToString())
    writer.close()

def main():
    fashion_mnist = input_data.read_data_sets('data/fashion', one_hot=True)
    train_images  = fashion_mnist.train.images
    train_labels  = fashion_mnist.train.labels
    test_images   = fashion_mnist.test.images
    test_labels   = fashion_mnist.test.labels
    write_tfrecord(train_images, train_labels, 'fashion_mnist_train.tfrecord')
    write_tfrecord(test_images, test_labels, 'fashion_mnist_test.tfrecord')

if __name__ == '__main__':
    main()

执行了上面的代码后,会在当前工作目录下生成两个TFRecord数据文件——fashion_mnist_train.tfrecordfashion_mnist_test.tfrecord

5.确认TFRecord的内容

如果我们想确认下刚才生成的TFRecord是否合乎我们的预期,tf.train.Example.FromString应该是不二之选了。

In [1]: import tensorflow as tf

In [2]: example = next(tf.python_io.tf_record_iterator("fashion_mnist_train.tfrecord"))

In [3]: tf.train.Example.FromString(example)
Out[3]:
features {
  feature {
  feature {
    key: "image"
    value {
      bytes_list {
        value: "\000...\000"
      }
    }
  }
  feature {
    key: "label"
    value {
      bytes_list {
        value: "\000...\000"
      }
    }
  }
}

由此可知,features包含了imagelabelheightwidth等特征。

6.读取TFRecord

为了完成这项任务,推荐使用tf.parse_single_example

def read_tfrecord(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.string)
        })

    image = tf.decode_raw(features['image'], tf.float32)
    label = tf.decode_raw(features['label'], tf.float64)

    image = tf.reshape(image, [28, 28, 1])
    label = tf.reshape(label, [10])

    image, label = tf.train.batch([image, label],
            batch_size=16,
            capacity=500)

    return image, label

7.整合

下面让我们把TFRecord使用到真实的模型训练场景中,虽然这次的Fashion MNIST数据量并不算大,完全可以一次性全部加载到内存中,但我们的TFRecord一样有用武之地,就是实现异步IO。

import numpy as np
import tensorflow as tf
import tfrecord_io
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import slim

def model(image, label):
    net = slim.conv2d(image, 48, [5,5], scope='conv1')
    net = slim.max_pool2d(net, [2,2], scope='pool1')
    net = slim.conv2d(net, 96, [5,5], scope='conv2')
    net = slim.max_pool2d(net, [2,2], scope='pool2')
    net = slim.flatten(net, scope='flatten')
    net = slim.fully_connected(net, 512, scope='fully_connected1')
    logits = slim.fully_connected(net, 10,
            activation_fn=None, scope='fully_connected2')

    prob = slim.softmax(logits)
    loss = slim.losses.softmax_cross_entropy(logits, label)

    train_op = slim.optimize_loss(loss, slim.get_global_step(),
            learning_rate=0.001,
            optimizer='Adam')

    return train_op

def main():
    train_images, train_labels = tfrecord_io.read_tfrecord('fashion_mnist_train.tfrecord')
    train_op = model(train_images, train_labels)

    step = 0
    with tf.Session() as sess:
        init_op = tf.group(
            tf.local_variables_initializer(),
            tf.global_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        while step < 3000:
            sess.run([train_op])

            if step % 100 == 0:
                print('step: {}'.format(step))

            step += 1

        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    main()

你可能感兴趣的:(算法)