Tensorflow 使用笔记:TFRecords
Tensorflow 的数据输入现在主要有两种形式:直接使用 Python 和 TFRecords . 在图像的项目中看到比较多的是直接自己实现dataprovider ,在 NLP 项目中见到比较多先做生成TFRecords 然后利用 tf.data.TFRecordDataset 来读取。我习惯 TRFRecords 的方式来实现。主要因为可以把数据清洗和模型处理的过程分开,二者不是混杂在一起。TFRecords 作为中间格式存在,生成什么样的 TFRecord 完全决定于你对要做的问题的理解,因为这里定义了你将要用到的特征。
我们通常有图像或文本这样的原始数据,拿图像分类或文本分类任务来说。我们的输入特征可能是图像的像素矩阵或者文本中词对应的 ID 而分类标签可能是对应标签的Id 或者甚至直接是字符串等等。Tensorflow 把这样的数据抽象成 Example 。 Example 有很多 Feature 这些 Feature 的数据类型主要有三种。TFrecord 中存储的就是 Example 对象对应的二进制数据,确切的说是使用 protobuf 序列化的二进制数据。在读取的使用 Tensorflow 提供的 DataSet API 在对序列化的数据解码的时候可以把想用的特征解码成对应的 Tensor 。简单的抽象和实现流程如下。
Python学习交流群:1004391443
Example
在创建 TFRecords 的过程中需要对Example 的定义比较好的理解。 数据类型抽象成三种:bytes, float, int64 , Feature 的基本组成单元是这三种数据的 list 定义如下:
message BytesList { repeated bytes value = 1; } message FloatList { repeated float value = 1 [packed = true]; } message Int64List { repeated int64 value = 1 [packed = true]; }
Feature 就是 BytesList, FloatList,Int64List 的封装
message Feature { // Each feature can be exactly one kind. oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } };
Feature 可以组成Map 状态的 Features :
message Features { // Map from feature name to feature. mapfeature = 1; };
还可以组成 FeatureList
message FeatureList { repeated Feature feature = 1; };
二者结合还可以产生下面类型
message FeatureLists { // Map from feature name to feature list. mapfeature_list = 1; };
如果对protobuf 的语法有了解的话,这些定义就很明了了。
Exmaple 的是 map 型的 Feature 的组合
message Example { Features features = 1; };
序列状态的 Example
message SequenceExample { Features context = 1; FeatureLists feature_lists = 2; };
了解这些定义之后,我们要做的就是把各种原始数据转成 bytes ,float ,int 类型然后构造成 Feature 然后组成 Example 序列到 文件中就好了 下面是一个完整的例子把 mnist 的数据序列化到 TFRecords
#!/usr/bin/env python #-*- coding:utf-8 -*- #author: wu.zheng midday.me import mnist import cv2 import os import sys import numpy as np import tensorflow as tf def _bytes_feature(value): if not isinstance(value, list): value = [value] return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) def write_tfrecord(data, labels, out_data_path): writer = tf.python_io.TFRecordWriter(out_data_path) counter = 0 total_count = len(data) for image, label in zip(data, labels): counter += 1 image = np.array(image) image = image.reshape((28, 28)) is_success, image_buffer = cv2.imencode(".jpg", image) if not is_success: continue label_value = [0] * 10 label_value[label] = 1 image_feature = _bytes_feature(image_buffer.tostring()); label_feature = _int64_feature(label_value) features = tf.train.Features(feature={"image":image_feature, "label":label_feature}) example = tf.train.Example(features=features) writer.write(example.SerializeToString()) sys.stdout.write("\r>>Writing to {:s} {:d}/{:d}".format(out_data_path, counter, total_count)) sys.stdout.flush() writer.close() sys.stdout.write("\n") sys.stdout.write(">>{:s} write finish. ".format(out_data_path)) def create_mnist_tfrecord(in_data_floder, out_data_floder ): meta_data = mnist.MNIST(in_data_floder) train_data, train_labels = meta_data.load_training() test_data, test_labels = meta_data.load_testing() train_tf_record_path = os.path.join(out_data_floder, 'train_mnist.tfrecord') test_tf_record_path = os.path.join(out_data_floder, 'test_mnist.tfrecord') write_tfrecord(train_data, train_labels, train_tf_record_path) write_tfrecord(test_data, test_labels, test_tf_record_path) if __name__ == "__main__": # datasets/mnist 下存放的是解压后的 mnist 数据, in_data_floder = "./datasets/mnist" out_data_floder = "./datasets/mnist_tfrecord" create_mnist_tfrecord(in_data_floder, out_data_floder)
上门例子还有些需要改进的地方,通常这个构建过程相对比较慢,几百万的数据可能会花费一两天的时间,所以需要多线程处理,生成一个 TFRecords 文件可能会很大,也不方便分布式,通常会把生成的文件划分成很多份。
Input_fn
有了 TFRecords 我们可以实现一个 input_fn 就好了,如果后面我们有新的数据要添加进来继续训练我们的模型,也只需要按照上门的步骤处理成 TFRecords, input_fn 不用做改变。在 input_fn 里面我们可以做数据增强等一些处理
在这里有个比较麻烦的是 Example 中定义的 Feature 会有与之对应的 tf.data.Feature. 有 VarLenFeature , SparseFeature , FixedLenFeature , FixedLenSequenceFeature 使用的是后选择合适的 Feature 就好了,他们本质上是对应这不通形态的 Tensor 比如 VarLenFeature 会产生一个 SparseTensor
下面是 mnist 数据的 一个 input_fn 的 实现:
#!/usr/bin/env python #-*- coding:utf-8 -*- #author: wu.zheng midday.me import tensorflow as tf def _decode_record(record_proto): feature_map = { "image": tf.FixedLenFeature((), tf.string), 'label': tf.VarLenFeature(tf.int64), } features = tf.parse_single_example(record_proto, features=feature_map) image = features['image'] image = tf.image.decode_jpeg(image, channels=1) image = tf.cast(image, tf.float32) paddings = tf.constant([[2, 2], [2, 2], [0,0]]) image = tf.pad(image, paddings, mode='CONSTANT', constant_values=0 ) image = image / 255.0 label = features['label'] example = {"image": image, "label": label} return example def input_fn(tfrecord_path, batch_size, is_training): dataset = tf.data.TFRecordDataset(tfrecord_path) if is_training: dataset = dataset.repeat().shuffle(buffer_size=10000) else: dataset = tf.repeat(1) dataset = dataset.map(lambda x: _decode_record(x)) dataset = dataset.batch(batch_size=batch_size) return dataset.make_one_shot_iterator() if __name__ == "__main__": tf_record_path = "./datasets/mnist_tfrecord/train_mnist.tfrecord" with tf.Session() as sess: iterator = input_fn(tf_record_path, 1, True) next_batch = iterator.get_next() sess.run(tf.global_variables_initializer()) while True: batch = sess.run(next_batch) image = batch['image'] print(image.shape) exit(0)