学习视频链接:TFRecord生成和读取-深度学习框架应用开发-TensorFlow 2.0 | 百科荣创在线学习平台
TFRecord格式介绍
正常读取数据集是从硬盘直接读取数据,这样需要先将数据读取出来再进行训练,这意味着需要通过IO对硬盘上的数据再次进行读取,再把数据放入内存中,之后再送入神经网络进行运算,由于读取数据需要等待时间,这样就造成了大部分资源的浪费,导致训练时间过长,基于此,tensorflow官方推荐了TFRecord读取数据的方法。
TFRecord是一种tensorflow的标准文件格式,实质是二进制文件,遵循protocol buffer协议。TFRecord文件方便复制移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取。
TFRecord先把文件读取出来,把每一个文件堆成队列,把文件中的数据按照队列写入TFRecord文件中,这样读取数据就相当于读取内存,加快了数据读入的速度。
TFRecord格式使用
1.指定原始数据的文件列表
2.创建文件列表队列
3.从文件读取数据
4.整理成Batch作为神经网络输入
TFRecord文件生成流程
1.获取每个文件的数据和标签
2.将数据和标签转换为特征
3.将特征写入TFRecord文件
创建TFRecord
TFRecord支持写入三种格式的数据:string、int64、float32
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
tf.train.Feature(int64_list=tf.train.Int64List(value=[feature.shape])) #整形数据写入形状
tf.train.Feature(float_list=tf.train.FloatList(value=[label])) #浮点型数据写入标签
通过feature建立example
tf.train.Example(features=tf.train.Features(feature=feature))
将example写入TFRecord文件
with tf.io.TFRecordWriter(tfrecord_file) as writer:
...
writer.write(example.SerializeToString())
TFRecord数据读取流程
1.读取TFRecord文件
2.定义特征结构
3.解析TFRecord中的数据和标签
代码实现案例
将猫狗数据集生成TFRecord文件并读取。
1.导入相关模块
import tensorflow as tf
import os
import matplotlib.pyplot as plt
2.准备文件路径和数据标签
train_cats_dir = '.../train/cat/'
train_dogs_dir = '.../train/dog/'
tfrecord_file = '.../train/train.tfrecord'
train_cat_filenames = [train_cats_dir + filename for filename in os.listdir(train_cats_dir)]
train_dog_filenames = [train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)]
train_filenames = train_cat_filenames + train_dog_filenames
#将cat类标签设置为0,dog类标签设置为1
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)
3.将数据转换为特征并写入TFRecord文件
#将数据转换为特征
with tf.io.TFRecordWriter(tfrecord_file) as writer:
for filename, label in zip(train_filenames, train_labels):
image = open(filename, 'rb').read() # 读取数据集图片到内存,image 为一个 Byte 类型的字符串
feature = { # 建立 tf.train.Feature 字典
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), # 图片是一个 Bytes 对象
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) # 标签是一个 Int 对象
}
example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通过字典建立 Example
# 将特征写入TFRecord文件中
writer.write(example.SerializeToString()) # 将Example序列化并写入 TFRecord 文件
4.读取TFRecord文件
raw_dataset = tf.data.TFRecordDataset(tfrecord_file) # 读取 TFRecord 文件
5.定义特征结构
#定义特征结构
feature_description = { # 定义Feature结构,告诉解码器每个Feature的类型是什么
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64),
}
6.解析TFRecord中的数据和标签
#解析TFRecord中的数据和标签
def _parse_example(example_string): # 将 TFRecord 文件中的每一个序列化的 tf.train.Example 解码
feature_dict = tf.io.parse_single_example(example_string, feature_description)
feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解码JPEG图片
return feature_dict['image'], feature_dict['label']
dataset = raw_dataset.map(_parse_example)
7.可视化
for image, label in dataset:
plt.title('cat' if label == 0 else 'dog')
plt.imshow(image)
plt.axis('off')
plt.show()