TensorFlow笔记_TFRecord生成和读取

学习视频链接:TFRecord生成和读取-深度学习框架应用开发-TensorFlow 2.0 | 百科荣创在线学习平台

TFRecord格式介绍

正常读取数据集是从硬盘直接读取数据,这样需要先将数据读取出来再进行训练,这意味着需要通过IO对硬盘上的数据再次进行读取,再把数据放入内存中,之后再送入神经网络进行运算,由于读取数据需要等待时间,这样就造成了大部分资源的浪费,导致训练时间过长,基于此,tensorflow官方推荐了TFRecord读取数据的方法。

TFRecord是一种tensorflow的标准文件格式,实质是二进制文件,遵循protocol buffer协议。TFRecord文件方便复制移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取。

TFRecord先把文件读取出来,把每一个文件堆成队列,把文件中的数据按照队列写入TFRecord文件中,这样读取数据就相当于读取内存,加快了数据读入的速度。

TFRecord格式使用

1.指定原始数据的文件列表

2.创建文件列表队列

3.从文件读取数据

4.整理成Batch作为神经网络输入

TFRecord文件生成流程

1.获取每个文件的数据和标签

2.将数据和标签转换为特征

3.将特征写入TFRecord文件

创建TFRecord

TFRecord支持写入三种格式的数据:string、int64、float32

tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))  
tf.train.Feature(int64_list=tf.train.Int64List(value=[feature.shape]))  #整形数据写入形状
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))  #浮点型数据写入标签

通过feature建立example

tf.train.Example(features=tf.train.Features(feature=feature))

将example写入TFRecord文件

with tf.io.TFRecordWriter(tfrecord_file) as writer:
    ...
    writer.write(example.SerializeToString())

TFRecord数据读取流程

1.读取TFRecord文件

2.定义特征结构

3.解析TFRecord中的数据和标签

代码实现案例

将猫狗数据集生成TFRecord文件并读取。

TensorFlow笔记_TFRecord生成和读取_第1张图片

1.导入相关模块

import tensorflow as tf
import os
import matplotlib.pyplot as plt

 2.准备文件路径和数据标签

train_cats_dir = '.../train/cat/'
train_dogs_dir = '.../train/dog/'
tfrecord_file = '.../train/train.tfrecord'

train_cat_filenames = [train_cats_dir + filename for filename in os.listdir(train_cats_dir)]
train_dog_filenames = [train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)]
train_filenames = train_cat_filenames + train_dog_filenames

#将cat类标签设置为0,dog类标签设置为1
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)

3.将数据转换为特征并写入TFRecord文件

#将数据转换为特征
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename, label in zip(train_filenames, train_labels):
        image = open(filename, 'rb').read()     # 读取数据集图片到内存,image 为一个 Byte 类型的字符串
        feature = {                             # 建立 tf.train.Feature 字典
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 图片是一个 Bytes 对象
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))   # 标签是一个 Int 对象
        }

        example = tf.train.Example(features=tf.train.Features(feature=feature))       # 通过字典建立 Example
        # 将特征写入TFRecord文件中
        writer.write(example.SerializeToString())   # 将Example序列化并写入 TFRecord 文件

4.读取TFRecord文件

raw_dataset = tf.data.TFRecordDataset(tfrecord_file)    # 读取 TFRecord 文件

5.定义特征结构

#定义特征结构
feature_description = {    # 定义Feature结构,告诉解码器每个Feature的类型是什么
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

6.解析TFRecord中的数据和标签

#解析TFRecord中的数据和标签
def _parse_example(example_string):   # 将 TFRecord 文件中的每一个序列化的 tf.train.Example 解码
    feature_dict = tf.io.parse_single_example(example_string, feature_description)
    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])    # 解码JPEG图片
    return feature_dict['image'], feature_dict['label']

dataset = raw_dataset.map(_parse_example)

7.可视化

for image, label in dataset:
    plt.title('cat' if label == 0 else 'dog')
    plt.imshow(image)
    plt.axis('off')
    plt.show()

TensorFlow笔记_TFRecord生成和读取_第2张图片

 

你可能感兴趣的:(深度学习,tensorflow,人工智能,python)