[Tensorflow]关于TFRecord和tf.Example的使用

 为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord是一种比较常用的存储二进制序列数据的方法,基于Google的Protocol buffers格式的数据。

tf.Example类是一种将数据表示为{"string": value}形式的meassage类型,Tensorflow经常使用tf.Example来写入、读取TFRecord数据

1. 关于tf.Example

1.1 tf.Example的数据类型

 一般来说,tf.Example都是{"string": tf.train.Feature}这样的键值映射形式。其中,tf.train.Feature类可以使用以下3种类型

  • tf.train.BytesList: 可以使用的类型包括 stringbyte

  • tf.train.FloatList: 可以使用的类型包括 floatdouble

  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64以及uint64

 为了将常用的数据类型(标量或list),转化为tf.Example兼容的tf.train.Feature类型,通过使用以下几个接口函数:

# 这里括号中的value是一个标量
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

为了示例的简洁,这里只是使用了变量。如果想要对张量进行处理,常用的方法是:使用tf.serialize_tensor函数将张量转化为二进制字符,然后使用_bytes_feature()进行处理;读取的时候使用tf.parse_tensor对二进制字符转换为Tensor类型。

 下面举一个简单例子了解一下,经过tf.train.Feature转换之后的结果

print(_bytes_feature(b'test_string'))

## 输出为: 
## bytes_list {
##   value: "test_string"    
## }

 所有的proto meassages都可以通过.SerializeToString方法转换为二进制字符串

feature = _float_feature(np.exp(1))
feature.SerializeToString()

# 输出为:b'\x12\x06\n\x04T\xf8-@'

1.2 创建一个tf.Example数据

 无论是什么类型,基于已有的数据构造tf.Example数据的流程是相同的:

  • 对于一个观测值,需要转化为上面所说的tf.train.Feature兼容的3种类型之一;

  • 构造一个字典映射,键key是string型的feature名称,值value
    是第1步中转换得到的值;

  • 第2步中得到的映射会被转换为Features类型的数据

假设存在一个数据集,包含4个特征:1个bool型,1个int型,1个string型以及1个float型;假设数据集的数量为10000

n_obeservations = int(1e4)

# bool类型的特征
feature0 = np.random.choice([False, True], n_observations)

# int型特征
feature1 = np.random.randint(0, 5, n_observations)

# string型特征
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# float型特征
feature3 = np.random.randn(n_observations)

 定义一个将各种类型封装的序列化函数

def serialize(feature0, feature1, feature2, feature3):
    feature = {
        "feature0": _int64_feature(feature0),
        "feature1": _int64_feature(feature1),
        "feature2": _bytes_feature(feature2),
        "feature3": _float_feature(feature3),
    }
    
    # 使用tf.train.Example创建Features的message
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

## Example
serialized_example = serialize_example(False, 4, b'goat', 0.9876)

## 使用tf.train.Example.FromString可视化结果
example_proto = tf.train.Example.FromString(serialized_example)

2. 使用tf.data读写TFRecord

2.1 写入TFRecord文件

 最简单的将数据读入dataset中的方法,就是使用tf.data.Dataset.from_tensor_slices函数

  • 只应用到一个array则返回一个标量dataset
tf.data.Dataset.from_tensor_slices(feature1)
  • 应用到一个array组成的元组,则返回一个元组dataset
features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))

## 如果想要从数据集中取一个样例
for f0, f1, f2, f3 in features_dataset.take(1):
    print(f0)
    break

## 输出: tf.Tensor(True, shape=[], dtype=bool),注意这里输出的是一个Tensor类型的元素

 使用tf.data.Dataset.map方法对Dataset中的每一个元素使用相同的方法。需要注意的是:×该函数只能在graph模式下运行,也就是说必须在graph中定义并且返回tf.Tensors类型。上面定义的serialize_example并不是返回tensor的函数,需要使用tf.py_function函数包装使其可以兼容,于是将上面定义的函数进行改进。使用tf.py_dunction需要制定shapetype

def serialize_pyfunction(feature0, feature1, feature2, feature3):

    # 由于上面输出是Tensor类型,所以要使用.numpy()获取其中的值
    feature = {
        "feature0": _int64_feature(feature0.numpy()), 
        "feature1": _int64_feature(feature1.numpy()),
        "feature2": _bytes_feature(feature2.numpy()),
        "feature3": _float_feature(feature3.numpy()),
    }
    
    # 使用tf.train.Example创建Features的message
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# 定义包装函数
def tf_serialize_example(f0 ,f1, f2, f3):
    tf_string = tf.py_function(
        serialize_pyfunction,
        (f0,f1,f2,f3),
        tf.string)
    
    return tf.reshape(tf_string, ())  # 结果是一个标量

# 对dataset中的每一个元素应用下面函数
serialized_features_dataset = features_dataset.map(tf_serialize_example)

将上面的结果写入TFRecord

filename = "test.tfrecord"
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

但是由于上面这种方法较为复杂,我们通常使用新的方法写入TFRecord,即使用tf.python_io模块中读写TFRecord文件的类

 使用新的方法,就不需要再重新定义包装函数,而是可以对数据逐个写入,即:

with tf.python_io.TFRecordWriter(filename) as writer:
    for i in range(n_observations):
        example = serialize(feature0[i], feature1[i], feature2[i], feature3[i])
        writer.write(example)

2.2 读取TFRecord文件

 这里介绍使用最简单的tf.data.TFRecordDataset方法读取数据,这个方法可以将一个或多个TFRecord文件的内容作为输入管道的一部分进行流式传输。

filenames = ["file1.tfrecord", "file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

TFRecordDataset的初始化filenames参数可以是字符串、字符串列表或者字符串tf.Tensor;如果有两组分别用于训练和验证的文件,可以使用tf.placeholder(tf.string)来表示文件名,并使用适当的文件初始化迭代器

filenames = tf.placehoder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map()   # 将记录解码为Tensor
dataset = dataset.repeat()  # 无限重复输入
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# 在不同的阶段,使用不同的值
training_files = ["train1.tfrecord", "train2.tfrecord"]
sess.run([iterator.initializer], feed_dict={filenames: training_files})

validation_files = []
sess.run(iterator,initializer, feed_dict={filenames: validation_files})

3. 一个读写TFRecord的实例

读阶段

import numpy as np 
import tensorflow as tf 
improt glob 
import matplotlib.image as mpimg

# 写入阶段
def images_to_tfrecords(data_path="mnist/", shuffle=True, random_seed=None):
    def int64_to_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    
    for s in ["train", "valid", "test"]:
        with tf.python_io.TFRecordWriter("mnist_%s.tfrecords"%s) as writer:
            img_paths = np.array([p for p in glob.iglob(
                    "%s%s/**/*.jpg" % (data_path, s), recursive=True)])
            
            if shuffle:
                rng = np.random.RandomState(random)
                rng.shuffle(img_paths)
            
            for idx, path in enumerate(img_paths):
                label = int(os.path.basename(os.path.dirname(path)))
                image = mpimg.imread(path)
                image = image.reshape(-1).tolist()
                
                example = tf.train.Example(features=tf.train.Features(
                    features={
                        "image": int64_to_feature(image),
                        "label": int64_to_feature([label])
                    }))
                writer.write(example.SerializeToString())

写阶段

def read_one_image(tfrecords_queue, normalize=True):
    reader = tf.TFRecordReader()
    key, value = reader.read(tfrecords_queue)
    features = tf.parse_single_example(value, 
        features={"label": tf.FixedLenFeature([], tf.int64),
            "image":tf.FixedLenFeature([784], tf.int64)
        })
    label = tf.cast(features['label'], tf.int32)
    image = tf.cast(features['image'], tf.float32)
    onehot_label = tf.one_hot(indices=label, depth=10)
    
    if normalize:
        image = image / 255
    return onehot_label, image

n_epochs = 15 
n_iter = n_epochs*(num_samples//batch_size)

g = tf.Graph()
with g.as_default():
    
    # 输入数据
    queue = tf.train.string_input_producer(["mnist_train.tfrecords"], num_epochs=None)
    label, image = read_one_image(queue)
    label_batch, image_batch = tf.train.shuffle_batch(
                [label, image], batch_size=batch_size,
                seed=random_seed, num_threads=8,
                capacity=5000,
                min_after_dequeue=2000)
    
    tf_images = tf.placeholder_with_default(image_batch,
                    shape=[None, 784], name="images")
    tf_labels = tf.placeholder_with_default(label_batch,
                    shape=[None, 10], name="labels")
    [...]
    

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    saver0 = tf.train.Saver()
    
    ## 创建一个线程管理器
    coord = tf.train.Coordinator()
    ## 启动入队线程
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    iter_per_epoch = n_iter // n_epochs
    
    for i in range(n_iter):
        [...]
        if not i % iter_per_epoch:
    
    ## 发出终止所有线程的命令
    coord.request_stop()
    ## 把线程加入主线程,等待threads结束
    coord.join(threads)
    
    [...]

使用数据进行测试

record_iterator = tf.python_io.tf_record_iterator(path="mnist_test.tfrecords")

with tf.Session() as sess:
    ## 重建保存在meta文件中的graph
    saver1 = tf.train.import_meta_graph("./**.meta")
    saver1.restore(sess, save_path="./mlp")
    
    for idx, r in enumerate(record_iterator):
        example = tf.train.Example()
        example.ParseFromString(r)
        label = example.features.feature['label'].int64_list.value[0]
        image = example = np.array(example.features.feature['image'].int64_list.value)
        
        pred = sess.run("prediction:0", feed_dict={"images:0": images.reshape(1,784)})
        [...]

参考链接

  • Tensorflow官方tfrecords使用教程
  • Tensorflow官方数据导入教程
  • DeepLearing Models

你可能感兴趣的:(深度学习)