如果你是 Tensorflow 的初学者,那么你或多或少在网络上别人的博客上见到过 TFRecord 的影子,但很多作者都没有很仔细地对它进行说明,这也许会让你感受到了苦恼。本文按照我自己的思路对此进行一番讲解,也许能够提供给你一些帮助。
TFRecord 是谷歌推荐的一种二进制文件格式,理论上它可以保存任何格式的信息。
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
上面是 Tensorflow 的官网给出的文档结构。整个文件由文件长度信息、长度校验码、数据、数据校验码组成。
但对于我们普通开发者而言,我们并不需要关心这些,Tensorflow 提供了丰富的 API 可以帮助我们轻松读写 TFRecord 文件。
TFRecord 的核心内容在于内部有一系列的 Example ,Example 是 protocolbuf 协议下的消息体。
在这里我相信大家都对 protocolbuf 比较了解,如果不了解也没有关系,它本质上和 xml 及 json 没有多大的区别。
网上有很多 example 的简单说明。
message Example {
Features features = 1;
};
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
熟悉 protobuf 同学看到这个格式定义就能马上明白了,不熟悉的同学可以点击相关的文章,我之前的这篇有对 protocolbuf 作过详细解释。
一个 Example 消息体包含了一系列的 feature 属性。
每一个 feature 是一个 map,也就是 key-value 的键值对。
key 取值是 String 类型。
而 value 是 Feature 类型的消息体,它的取值有 3 种:
需要注意的是,他们都是列表的形式。
protocolbuf 是通用的协议格式,对主流的编程语言都适用。所以这些 List 对应到 python 语言当中是 列表,而对于 Java 或者 C/C++ 来说他们就是数组。
举个例子,一个 BytesList 可以存储 Byte 数组,因此像字符串、图片、视频等等都可以容纳进去。
所以 TFRecord 可以存储几乎任何格式的信息。
但需要说明的是,更官方的文档来源于 Tensorflow的源码,这里面有详细的定义及注释说明。
TFRecord 也不是非用不可,但它确实是谷歌官方推荐的文件格式。
1、它特别适应于 Tensorflow ,或者说它就是为 Tensorflow 量身打造的。
2、因为 Tensorflow开发者众多,统一训练时数据的文件格式是一件很有意义的事情。也有助于降低学习成本和迁移成本。
TFRecord 是一种文件格式,那么对于 TFRecord 文件的 IO 怎么处理呢?
事实上,Tensorflow 给我们提供了丰富的 API ,开发者运用这些 API 可以轻松地处理 TFRecord 文件。
我们可以利用 TFWriter
轻松完成这个任务。
但制作之前,我们要先明确自己的目的。
我们必须想清楚,要把什么信息存储到 TFRecord 文件当中,这其实是最重要的。
下面,举例说明。
因为深度学习很多都是与图片集打交道,那么,我们可以尝试下把一张张的图片转换成 TFRecord 文件。
首先定义 Example 消息体。
Example Message {
Features{
feature{
key:"name"
value:{
bytes_list:{
value:"cat"
}
}
}
feature{
key:"shape"
value:{
int64_list:{
value:689
value:720
value:3
}
}
}
feature{
key:"data"
value:{
bytes_list:{
value:0xbe
value:0xb2
...
value:0x3
}
}
}
}
}
上面的 Example 表示,要将一张 cat 图片信息写进 TFRecord 当中,而图片信息包含了图片的名字,图片的维度信息还有图片的数据,分别对应了 name、shape、content 3 个 feature。
下面,我们开始用代码实现它。
def write_test(input,output):
''' 借助于 TFRecordWriter 才能将信息写进 TFRecord 文件'''
writer = tf.python_io.TFRecordWriter(output)
# 读取图片并进行解码
image = tf.read_file(input)
image = tf.image.decode_jpeg(image)
with tf.Session() as sess:
image = sess.run(image)
shape = image.shape
# 将图片转换成 string。
image_data = image.tostring()
print(type(image))
print(len(image_data))
name = bytes("cat", encoding='utf8')
print(type(name))
# 创建 Example 对象,并且将 Feature 一一对应填充进去。
example = tf.train.Example(features=tf.train.Features(feature={
'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}
))
# 将 example 序列化成 string 类型,然后写入。
writer.write(example.SerializeToString())
writer.close()
write_test('cat.jpg','cat.tfrecord')
运行上面的代码,就可以在当前目录生成 cat.tfrecord 文件。
上面代码注释都比较详细,我挑重点来讲。
上一节是讲如何将一张图片的信息写入到一个 tfrecord 文件当中。
现在,我们需要检验它是否正确,这就需要用到如何读取 TFRecord 文件的知识点了。
def _parse_record(example_proto):
features = {
'name': tf.FixedLenFeature((), tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'data': tf.FixedLenFeature((), tf.string)}
parsed_features = tf.parse_single_example(example_proto, features=features)
return parsed_features
def read_test(input_file):
# 用 dataset 读取 tfrecord 文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
features = sess.run(iterator.get_next())
name = features['name']
name = name.decode()
img_data = features['data']
shape = features['shape']
print('=======')
print(type(shape))
print(len(img_data))
# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
img_data = np.fromstring(img_data,dtype=np.uint8)
image_data = np.reshape(img_data,shape)
plt.figure()
#显示图片
plt.imshow(image_data)
plt.show()
#将数据重新编码成 jpg 图片并保存
img = tf.image.encode_jpeg(image_data)
tf.gfile.GFile('cat_encode.jpg','wb').write(img.eval())
read_test('cat.tfrecord')
代码比较简单,我也有给详细的注释,我挑重要的几点讲解一下。
- 我用 dataset 去读取 tfrecord 文件
- 在解析 example 的时候,用现成的 API 就好了 tf.parse_single_example
- 用 np.fromstring() 方法就可以获取解析后的 string 数据,记得数据格式还原成 np.uint8
- 用 tf.image.encode_jpeg() 方法可以将图片数据编码成 jpeg 格式。
- 用 tf.gfile.GFile 对象可以将图片数据保存到本地。
- 因为将图片 shape 写进了 example 中,解析的时候必须制定维度,在这里是 [3] ,不然程序报错。
运行程序后,可以看到图片显示正常.
并且将 TFRecord 中的图片数据也成功地保存到本地了。
Q:我的示例为什么用 Dataset 而不用大多数博文中的 QueueRunner 呢?
A:这是因为 Dataset 比 QueueRunner 新,而且是官方推荐的,Dataset 比较简单。
Q:学习了 TFRecord 相关知识,下一步学习什么?
A:可以尝试将常见的数据集如 MNIST 和 CIFAR-10 转换成 TFRecord 格式。
下一篇博文,我就将怎么将 CIFAR-10 转换成 TFRecord 格式人数据集,然后构建简单的神经网络去实验它。