一、TFRecords的数据结构
TFRecords数据集是一种二进制的数据集,是tensorflow推荐的标准文件格式。Tensorflow通过ProtocolBuffers定义了TFRecords文件中存储的记录及其所含的字段结构,使用该方式可以将数据,标签以及和数据相关的信息通过key,value的形式存储在同一个文件中,并通过key,value的形式对存储的数据进行读取。该数据结构定义在tensrflow/core/example目录下的example.proto和feature.proto文件中,因此在构建实例时我们将转化后的张量称为样例,其内部记录称为特征域。
关于TFRecords的具体结构如下:
example = tf.train.Example(
features=tf.train.Features(
feature={
}))
其中example就是 样例,其中包含一个Features类型的数结构其命名为features,一个Features类型的数据结构又包含一个feature,feature中是多个key,value结构的数据,key是一个字符型数据,value是一个Feature型数据。
message Features {
map
}
message Feature{
one of kind {
ByteList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3 ;
}
}
二、TFRecords的写入方式(读入图片数据为例,存储对应的图片内容,宽,高,标签)
写入TFRecords数据时需主要分三步:
1. 定义对应的数据结构:
def image_to_tfexample(image_data, image_format, height, width, class_id):
return tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),
'image/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=class_id)),
'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
}))
2. 定义数据的读取方式:该步骤主要是将原始的数据读入并转化成相应的Bytes或者Int64的格式放入到对应的结构中,本例中读入的是图片,所以读取的方式如下:
for i in range(len(path)):
image_data = tf.gfile.FastGFile(path[i], 'rb').read()
image = tf.image.decode_jpeg(image_data)
image = sess.run(image)
print(image.shape)
height = image.shape[0]
width = image.shape[1]
class_id = [i]
3. 定义TFRecords生成的名字、写入和关闭文件:在该部分首先要在循环外面使用
writer = tf.python_io.TFRecordWriter('image_test.tfrecord')定义一个输出的文件名。
接着在循环中使用下面的语句将数据写入:
example = image_to_tfexample(image_data, b'jpg', height, width, class_id)
writer.write(example.SerializeToString())
最后使用writer.close()关闭文件。
通过上述运行之后会在目录中出现一个名为image_test.tfrecord的二进制文件,该文件中存储了所有相关的信息。
三、 TFRecords数据的读取。
关于该数据在读取时主要分为两部分:
1. 构造读取结构:该结构是
keys_to_features = {
'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
'image/format': tf.FixedLenFeature([], tf.string, default_value=''),
'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
'image/height': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
'image/width': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}
2. 构建读取器和解析器(将样例转换为张量)
filename_queue = tf.train.string_input_producer(['image_test.tfrecord'], num_epochs=2)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 将一条序列化的样例转化为其包含所有特征的向量
features = tf.parse_single_example(
serialized_example,
features=keys_to_features
)
其中TFRecordReader()表示用来读取tfrecord格式的数据,其对应的read方法要求传入的是一个queue。得到对应的样例, tf.parse_single_example的作用是将输入的一个样例按照传入的features字典的形式转化成对应的张量。
至此就完成了对TFRecords数据的存储和读取,由于读取出的数据是tensor,因此要使用sess.run()的方式对数据进行显示,同样也可以对读取的数据进行和正常的数据同样的操作。
四、读取数据的操作
example = sess.run(features)
image1 = tf.image.decode_jpeg(example['image/encoded'], channels=3)
print(image1.shape)
此时image1的形状为[ ?, ?,3]
image2 = tf.image.resize_images(image1, [160, 160], method=1)
print(image2.shape)
此时image2的形状是 [160, 160, 3]
height = example['image/height']
print(height)
image = sess.run(image1)
print(image.shape)
此时image1的形状是 [250, 196, 3] 原始的图像大小。
# image = sess.run(image2)
# print(image2.shape)
print(example)
path = "image1/" + str(height) + str(i) + ".jpg"
misc.imsave(path, image)
对图片进行存储。
我们运行eaxmple时可以发现其结果是我们定义的
由于read方法读入的是一个队列,因此关于如何使用队列和线程进行数据的读取可以参考https://blog.csdn.net/hh_2018/article/details/81143109