深入浅出tfrecord数据格式的保存与读取,同时适用于tf1与tf2

问题

当我们在保存海量小文件的数据的时候,是否碰到过这样的问题?

保存的文件数量太多,每次加载的时候,受python IO限制,读取文件速度很慢,即使是使用多进程,速度仍然无法满足要求

有同学可能会问,可以保存成矩阵形式保存,然而矩阵形式保存需要每个样本的维度是一致的,在不一样的维度,需要进行padding,有时候padding的部分会占用大量内存,因此对每个样本进行单独保存是最好的方式。tfrecord可以完美的解决我们的问题,这里介绍一下实战了2年的tfrecord简单使用方法,同时支持tf1与tf2

保存

首先我们需要定义好6种输入类型

import tensorflow as tf


def bytes_feature(value: bytes):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def int64_feature(value: int):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def float_feature(value: float):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

有了这6种输入类型,我们就可以进行tfrecord数据的写入了

writer = tf.io.TFRecordWriter("data.tfrecords") # 生成空的tfrecord文件
example = tf.train.Example(features=tf.train.Features(feature={
            'feature': bytes_feature(outputs.tobytes()),
            'feature_height': int64_feature(outputs.shape[0]),
            'feature_width': int64_feature(outputs.shape[1]),
            'label': bytes_feature(label.tobytes()),
        })) # 将array形式的feature和label变成tfrecord格式的样本
writer.write(example.SerializeToString()) # 对tfrecord样本进行序列化后保存
writer.close() # 关闭tfrecord文件,完成tfrecord文件的生成

在进行上述样例的复现的时候,可以将output定义为np.ones((1, 2048)),将label定义为np.zeros((1, 1))

读取

tfrecord数据的读取方式也是很方便的

dataset = tf.data.TFRecordDataset('data.tfrecord') #生成dataset
feature_map = {'feature': tf.io.FixedLenFeature((), tf.string),
                       'feature_height': tf.io.FixedLenFeature((), tf.int64),
                       'feature_width': tf.io.FixedLenFeature((), tf.int64),
                       'label': tf.io.FixedLenFeature((), tf.string),
                       } #定义feature map
for x in dataset:
	parsed_example = tf.io.parse_single_example(x, feature_map) #按照feature map对数据进行解析
	feature = tf.io.decode_raw(parsed_example["feature"], out_type=tf.float32)
	height = parsed_example["feature_height"]
    width = parsed_example["feature_width"]
    label = tf.io.decode_raw(parsed_example["label"])
    feature = tf.reshape(feature, [height, width]) # 将feature进行还原
    label = tf.reshape(label, [height, 1]) # 将label进行还原

在tf2中,上述解析后的数据都可以通过.numpy()的方式还原为numpy数组,对于tf1来说,我们需要用session的方式进行读取,如果有需要的话再进行更新。
至此我们便实现了tfrecord数据的保存与读取,上述例子均可以单独复现,无需其他操作。

你可能感兴趣的:(tensorflow,python,tensorflow,深度学习)