Without further ado, let's get right into it.
The previous post briefly covered generating and loading the TFRecord data format; this post introduces another data loading approach: Dataset.
No matter which loading approach you use, the same few steps apply: read the raw data information (image paths and labels), serialize it into a storage format (here TFRecord), and build an input pipeline that feeds batches to the network.
The rest of this post walks through loading image data with Dataset:
Prerequisite: the test dataset used in this post contains only 6 images, three cats and three dogs. The directory structure is as follows:
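A hypothetical layout consistent with the code in step 1 (folder and file names are illustrative; the only requirement is one sub-folder per class under dataset/):

dataset/
    cat/
        cat_1.jpg
        cat_2.jpg
        cat_3.jpg
    dog/
        dog_1.jpg
        dog_2.jpg
        dog_3.jpg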
1) Read the raw data information and save it to a txt file, one image per line, in the format:
image path <TAB> image label (labels start at 0 and increase by one, which makes the later one-hot encoding straightforward). A sample of the resulting file contents is sketched after the code.
import os

current_path = os.path.dirname(os.path.abspath(__file__))
datasets_path = os.path.join(current_path, "dataset")
filename = os.path.join(current_path, "dataset_image_list.txt")

if not os.path.exists(filename):
    with open(filename, "w") as f_obj:
        # Each sub-folder of dataset/ is one class; its enumeration index becomes the label.
        for cls_index, cls_name in enumerate(os.listdir(datasets_path)):
            print(cls_name)
            if os.path.isdir(os.path.join(datasets_path, cls_name)):
                print("#" * 40, cls_name, "#" * 40)
                for img_index, img_name in enumerate(os.listdir(os.path.join(datasets_path, cls_name))):
                    if os.path.isfile(os.path.join(datasets_path, cls_name, img_name)):
                        img_path = os.path.join(datasets_path, cls_name, img_name)
                        # One line per image: <image path>\t<class index>
                        f_obj.write(img_path + "\t" + str(cls_index) + "\n")
else:
    print("file exists")
Another benefit of first writing the image information to a txt file is that later processing, whether generating TFRecords or building a Dataset, can easily shuffle the dataset. Otherwise images with the same label would be read consecutively, and you would have to shuffle again during network training.
2) Load the txt file generated in step 1 and save the image information it lists into TFRecord files.
import os
import cv2
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np

current_path = os.path.dirname(os.path.abspath(__file__))
datasets_path = os.path.join(current_path, "dataset_image_list.txt")
train_tfreocrd_filename = os.path.join(current_path, "tfrecord_files", "cat_and_dog_train.tfrecords")
validation_tfreocrd_filename = os.path.join(current_path, "tfrecord_files", "cat_and_dog_validation.tfrecords")
image_size = 224
image_channel = 3


def generate_tfrecord():
    if os.path.exists(train_tfreocrd_filename):
        return

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    train_writer = tf.python_io.TFRecordWriter(path=train_tfreocrd_filename)
    validation_writer = tf.python_io.TFRecordWriter(path=validation_tfreocrd_filename)

    with open(datasets_path, "r") as f_obj:
        image_list = f_obj.readlines()
        image_list_len = len(image_list)
        print(image_list_len)
        # Shuffle the image list once here so the records are not written in class order.
        permutation = np.random.permutation(image_list_len)
        print(permutation)
        img_list = []
        for i in permutation:
            img_list.append(image_list[i])
        for img_index, img_info in enumerate(img_list):
            img_path, img_class = img_info.split()
            img_class = int(img_class)
            print(img_path, img_class)
            if os.path.isfile(img_path):
                # Only keep 3-channel color images; convert BGR (OpenCV) to RGB and resize.
                img = cv2.imread(filename=img_path)
                if img.ndim == image_channel:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_AREA)
                    img_pixles = img.shape[0] * img.shape[1] * img.shape[2]
                    # plt.imshow(img)
                    # plt.show()
                else:
                    continue
                image_raw = img.tostring()
                # print(image_raw)
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            "pixels": _int64_feature(img_pixles),
                            "label": _int64_feature(img_class),
                            "image_raw": _bytes_feature(image_raw)
                        }
                    )
                )
                # Every other image goes to the validation file, the rest to the training file.
                if img_index % 2 == 0:
                    validation_writer.write(example.SerializeToString())
                else:
                    train_writer.write(example.SerializeToString())
    train_writer.close()
    validation_writer.close()


def read_from_tfrecord(sess):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer(
        string_tensor=[train_tfreocrd_filename, validation_tfreocrd_filename])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            "pixels": tf.FixedLenFeature([], tf.int64),
            "label": tf.FixedLenFeature([], tf.int64),
            "image_raw": tf.FixedLenFeature([], tf.string)
        }
    )
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features["label"], tf.int64)
    pixels = tf.cast(features["pixels"], tf.int32)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(6):
        image_raw, image_label, _ = sess.run([image, label, pixels])
        image_raw = np.reshape(image_raw, (image_size, image_size, image_channel))
        print(image_raw.shape)
        plt.imshow(image_raw)
        plt.xlabel(str(image_label))
        plt.show()
    # Shut down the queue runners cleanly before the session exits.
    coord.request_stop()
    coord.join(threads)


if __name__ == '__main__':
    with tf.Session() as sess:
        generate_tfrecord()
        read_from_tfrecord(sess=sess)
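Before moving on, it can be worth sanity-checking the two generated files. A minimal sketch, assuming the TF 1.x API tf.python_io.tf_record_iterator and the file paths defined above:

import tensorflow as tf

def count_records(tfrecord_path):
    # Iterate over the serialized Examples in the file and count them.
    return sum(1 for _ in tf.python_io.tf_record_iterator(path=tfrecord_path))

# With 6 source images and the even/odd split above, each file should hold 3 records.
print(count_records(train_tfreocrd_filename))
print(count_records(validation_tfreocrd_filename))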
3) Build a Dataset from the TFRecord files generated in step 2.
import os
import tensorflow as tf


def _parse_tfrecord_features(serialized_feature):
    features = {
        "pixels": tf.FixedLenFeature([], tf.int64),
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string)
    }
    image_features = tf.parse_single_example(
        serialized_feature,
        features=features
    )
    # Decode the raw bytes back to uint8 pixels and rescale them to [0, 1] floats.
    image_raw = tf.decode_raw(bytes=image_features["image_raw"], out_type=tf.uint8)
    image_raw = tf.image.convert_image_dtype(image_raw, tf.float32)
    labels = tf.cast(image_features["label"], tf.int64)
    # Two classes (cat / dog), so depth=2 for the one-hot encoding.
    image_labels = tf.one_hot(labels, depth=2)
    return image_raw, image_labels


def _image_reshape(image_data, image_label):
    # The decoded image is a flat vector; restore the (height, width, channel) shape.
    return tf.reshape(image_data, (224, 224, 3)), image_label


current_path = os.path.dirname(os.path.abspath(__file__))
train_tfrecords = os.path.join(current_path, "tfrecord_files", "cat_and_dog_train.tfrecords")
validation_tfrecords = os.path.join(current_path, "tfrecord_files", "cat_and_dog_validation.tfrecords")

repeat_count = 3
batch_size = 16
shuffle_buffer_size = 4096
train_epoch = 10
image_size = 224
image_channel = 3

train_tfrecord_datasets = tf.data.TFRecordDataset(filenames=train_tfrecords)
train_tfrecord_datasets = train_tfrecord_datasets.map(_parse_tfrecord_features)
train_tfrecord_datasets = train_tfrecord_datasets.repeat(count=repeat_count)
train_tfrecord_datasets = train_tfrecord_datasets.map(_image_reshape)
train_tfrecord_datasets = train_tfrecord_datasets.shuffle(buffer_size=shuffle_buffer_size,
                                                          reshuffle_each_iteration=True)
train_tfrecord_datasets_batch = train_tfrecord_datasets.batch(batch_size=batch_size)

validation_tfrecord_datasets = tf.data.TFRecordDataset(filenames=validation_tfrecords)
validation_tfrecord_datasets = validation_tfrecord_datasets.map(_parse_tfrecord_features)
validation_tfrecord_datasets = validation_tfrecord_datasets.map(_image_reshape)
validation_tfrecord_datasets = validation_tfrecord_datasets.shuffle(buffer_size=shuffle_buffer_size,
                                                                    reshuffle_each_iteration=True)
validation_tfrecord_datasets_batch = validation_tfrecord_datasets.batch(batch_size=batch_size)

# The training and validation datasets share the same structure, so one
# reinitializable iterator can serve both.
iterator = tf.data.Iterator.from_structure(
    output_types=train_tfrecord_datasets_batch.output_types,
    output_shapes=train_tfrecord_datasets_batch.output_shapes
)
train_init_op = iterator.make_initializer(dataset=train_tfrecord_datasets_batch)
validation_init_op = iterator.make_initializer(dataset=validation_tfrecord_datasets_batch)
next_batch = iterator.get_next()

if __name__ == '__main__':
    with tf.Session() as sess:
        for epoch in range(train_epoch):
            sess.run(train_init_op)
            print("#" * 30, "training", "#" * 30)
            while True:
                try:
                    train_data_batch, train_label_batch = sess.run(next_batch)
                    print(train_label_batch)
                except tf.errors.OutOfRangeError as e:
                    # print(e)
                    break
            print("#" * 30, "validation", "#" * 30)
            sess.run(validation_init_op)
            while True:
                try:
                    validation_data_batch, validation_label_batch = sess.run(next_batch)
                    print(validation_label_batch)
                except tf.errors.OutOfRangeError as e:
                    # print(e)
                    break
A few points are worth explaining:
1) The _parse_tfrecord_features function parses the image information stored in the TFRecord file, i.e. it recovers the image matrix and the one-hot encoding of the image label.
2) The _image_reshape function restores the parsed image matrix to its 3-D shape (224, 224, 3).
3) train_tfrecord_datasets.map(_parse_tfrecord_features) and train_tfrecord_datasets.map(_image_reshape) apply _parse_tfrecord_features and _image_reshape, respectively, to every element of the dataset, producing a Dataset that matches the network's expected input. A tiny toy example of this per-element behaviour is sketched below.
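A minimal, hypothetical sketch (not part of the pipeline above, TF 1.x API) showing that each map() call runs its function once per element:

import tensorflow as tf

toy = tf.data.Dataset.from_tensor_slices([1, 2, 3])
toy = toy.map(lambda x: x * 10)   # first per-element transform
toy = toy.map(lambda x: x + 1)    # second per-element transform
nxt = toy.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(nxt))  # 11
    print(sess.run(nxt))  # 21
    print(sess.run(nxt))  # 31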
4) Since the training data and the validation data share the same structure, tf.data.Iterator.from_structure is used to build the iterator object; note that the structure is taken from train_tfrecord_datasets_batch:
iterator = tf.data.Iterator.from_structure(
    output_types=train_tfrecord_datasets_batch.output_types,
    output_shapes=train_tfrecord_datasets_batch.output_shapes
)
5) The following two lines create the initialization ops for the training Dataset and the validation Dataset:
train_init_op = iterator.make_initializer(dataset=train_tfrecord_datasets_batch)
validation_init_op = iterator.make_initializer(dataset=validation_tfrecord_datasets_batch)
6) When you need training data, first run sess.run(train_init_op) to initialize the iterator, then call train_data_batch, train_label_batch = sess.run(next_batch) to get a training batch. Likewise, for validation data first run sess.run(validation_init_op), then validation_data_batch, validation_label_batch = sess.run(next_batch) to get a validation batch.
Note: as the iterator keeps fetching the next element, a tf.errors.OutOfRangeError is raised once the data has been exhausted. If you then want to make another pass over the data, you must run the corresponding initialization op (train_init_op or validation_init_op) again.