TensorFlow 2.1.0 使用 TFRecord 转存与读取图片

前言

当 NLP 玩家遇到一个 CV 图像分类的任务时,老实的说,我是有点懵逼的。。。

任务目标是,训练一个层数较少,结构较为简单的图像分类模型,使用其最后一层隐藏层输出,作为特征向量来表征图像。

之前都是使用 Keras 较多,于是本次准备借着这个简单的任务上手 TensorFlow 2.1 。


数据加载

Python generator 出现的问题

TensorFlow 2.1 自带的 tf.data.Dataset 处理训练数据十分好用,并且自带 shuffle,repeat,和划分 batch 的方法。可以通过python generator, numpy list, Tensor slices 等数据结构直接构成 Dataset。

我训练使用的数据是文档中的插图,5个类别共 10w 张。

起初我使用的方法是:构造一个 python generator,训练时,使用 tf 自带的 tf.io.read_file() 和 tf.image.decode_jpeg() 方法从磁盘中读取数据,再使用 tf.data.Dataset.from_generator 生成数据集。

但训练时发现这样的数据处理有着很大的问题:受制于generator 的读取数据速度,batch 数据生成的速度跟不上 GPU 的训练速度,导致 GPU 的利用率只有不到 10%,训练速度很慢。很慢。。很慢。。。

 

 TFRecord

这时我想到了 TFRecord 。TFRecord 可以将数据转存为二进制文件保存,这样在训练时读取数据就不会遇到以上的问题了。

使用 TFRecord 来进行数据处理,首先需要将原始图片数据转存为 TFRecord 格式:

def pares_image(pic):
'''
    图片预处理,并转为字符串
'''
    label = pares_label(pic)
    try:
        img = Image.open(pic)
        img = img.convert('RGB')
        img = img.resize((54,96))
        img_raw = img.tobytes()
    except:
        return None, None
    return img_raw, label


train_data_list = get_file_path(train_data_path)


writer = tf.io.TFRecordWriter('./data/test_data')
for data in tqdm(test_data_list):
    img_raw, label = pares_image(data)
    if (img_raw is not None) and (label != 'not valid'):
        exam = tf.train.Example(
            features = tf.train.Features(
                feature = {
                    'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label_2_idx[label])])),
                    'data' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }
            )
        )
        writer.write (exam.SerializeToString())
writer.close()  

文件读取

接下来读取 TFRecord 文件,加载进 tf.data,Dataset 

train_reader = tf.data.TFRecordDataset('./data/train_data')
test_reader = tf.data.TFRecordDataset('./data/test_data')
valid_reader = tf.data.TFRecordDataset('./data/valid_data')

feature_description = {
    'data' : tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([] , tf.int64, default_value=-1)
}
def _parse_function (exam_proto):
    temp = tf.io.parse_single_example (exam_proto, feature_description)
    img = tf.io.decode_raw(temp['data'], tf.uint8)
    img = tf.reshape(img, [54, 96, 3])
    img = img / 255
    label = temp['label']
    return (img, label)


train_dataset = train_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
test_dataset = test_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
valid_dataset = test_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)

读取文件时,需要重新将二进制的图像数据重新 decode 为 Tensor,并进行数值归一化处理。

你可能感兴趣的:(TensorFlow,Python,tensorflow,TFRecord)