对于数据量很大的数据集，一次性全部读入内存可能放不下。建议的做法是先把全部数据转换成 TFRecord 格式，训练时再通过 tf.data 按需流式读取；TensorFlow 对读取 TFRecord 专门做过优化，能加快读取速度。
参考资料: 官方tfrecord读写教程
方法1：直接以二进制 bytes 读取图片文件（即 PNG/JPEG 等编码后的文件字节），然后放进 TFRecord 中。但这样无法对像素内容做修改，比如有时 label 需要做映射（map），这时就要用方法2。
import tensorflow as tf
def _bytes_list_feature(value):
    """Wrap a single ``bytes`` value in a ``tf.train.Feature``.

    Args:
        value: one bytes object (e.g. the contents of an image file).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly ``value``.
    """
    single_item_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_item_list)
def image_seg_to_tfexample(image_data, seg_data):
    """Pack an (image, label) pair of byte strings into a ``tf.train.Example``.

    The values are stored under the keys ``'image'`` and ``'label'`` as
    single-element bytes lists, which is what the parsers in this file read
    back with ``tf.FixedLenFeature((), tf.string)``.

    Args:
        image_data: bytes for the input image.
        seg_data: bytes for the segmentation label.

    Returns:
        The assembled ``tf.train.Example``.
    """
    feature_map = {
        'image': _bytes_list_feature(image_data),
        'label': _bytes_list_feature(seg_data),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
# Method 1: store the encoded image/label files (their raw file bytes) as-is.
# NOTE(review): image_filename, seg_filename and output_filename are not
# defined in this snippet; they are assumed to come from the calling script.
# The with-statements close the file handles right after reading (the bare
# tf.gfile.GFile(...).read() calls never closed them).
with tf.gfile.GFile(image_filename, 'rb') as image_file:
    # Same result as open(image_filename, 'rb').read(): a bytes object.
    image_data = image_file.read()
with tf.gfile.GFile(seg_filename, 'rb') as seg_file:
    seg_data = seg_file.read()
# tf.read_file(image_filename) is NOT a drop-in replacement here: it returns
# a scalar tf.string Tensor, which must be evaluated in a session before it
# can be put into a tf.train.BytesList.
with tf.python_io.TFRecordWriter(output_filename) as writer:
    example = image_seg_to_tfexample(image_data, seg_data)
    # Serialize the tf.train.Example and append it to the TFRecord file.
    writer.write(example.SerializeToString())
方法2：不直接把图片文件读取成 bytes，而是先解码成 ndarray，这样可以对 ndarray 进行修改，再把像素数据转换成 bytes 写入 TFRecord。注意这样存储的是未压缩的原始像素，而且不包含图片形状，解析时需要用 tf.decode_raw() 读回，再 reshape 成事先已知的形状。
from PIL import Image
import numpy as np
import tensorflow as tf
# Method 2: decode the files into ndarrays so the pixels can be edited (here:
# the label ids are remapped), then store the raw pixel bytes.
# NOTE(review): image_filename, seg_filename and output_filename are not
# defined in this snippet; they are assumed to come from the calling script.
import pickle  # used below but not imported anywhere else in this file

# Pickled dictionary {original label value: new label value} used for the remap.
MAP_DICT_PATH = ('/home/steven/deeplab_v3+_project/'
                 'deeplab_v3+_tensorflow_from_rishizek/map_dictionary.pickle')

# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
with open(MAP_DICT_PATH, 'rb') as f:
    map_dict = pickle.load(f)

# Read the image as uint8 (pixel values are expected to lie in 0-255).
# The with-statement releases the file handle that Image.open holds.
with Image.open(image_filename) as img:
    image_array = np.array(img).astype(np.uint8)
# Raw pixel bytes (tobytes() is the non-deprecated name of tostring()).
# The array shape is NOT stored, so the reader must reshape to a known size.
image_data = image_array.tobytes()

# Read the label without casting yet: the remapped values may need another dtype.
with Image.open(seg_filename) as seg_img:
    seg_array = np.array(seg_img)

# Fail loudly on label values the dictionary does not cover; with
# map_dict.get they turned into None and caused an obscure TypeError later.
missing = set(np.unique(seg_array).tolist()) - set(map_dict)
if missing:
    raise KeyError('label values missing from map_dict: {}'.format(sorted(missing)))
seg_data_mapped = np.vectorize(map_dict.__getitem__)(seg_array)

# Labels are stored as uint8, so the mapped values must fit in 0-255;
# otherwise astype(np.uint8) would silently wrap them around.
if seg_data_mapped.min() < 0 or seg_data_mapped.max() > 255:
    raise ValueError('mapped label values must lie in [0, 255] to be stored as uint8')
seg_data = seg_data_mapped.astype(np.uint8).tobytes()

with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
    # image_seg_to_tfexample() is the helper defined for method 1.
    example = image_seg_to_tfexample(image_data, seg_data)
    tfrecord_writer.write(example.SerializeToString())
def get_filenames(is_training, data_dir, num_shards=4):
    """Return the TFRecord shard paths for the training or validation split.

    Shards are named ``nonzeros_<split>-<index>-of-<num_shards>.tfrecord``
    with 5-digit zero-padded numbers, where ``<split>`` is ``train`` when
    ``is_training`` is true and ``valid`` otherwise.

    Args:
        is_training: True for the training shards, False for validation.
        data_dir: directory that contains the shards.
        num_shards: number of shards per split (defaults to the original 4).

    Returns:
        A list of ``num_shards`` paths, in shard order.
    """
    import os  # os is not imported elsewhere in the visible source

    split = 'train' if is_training else 'valid'
    return [
        os.path.join(
            data_dir,
            'nonzeros_{}-{:05d}-of-{:05d}.tfrecord'.format(split, shard, num_shards))
        for shard in range(num_shards)
    ]
# Stream every record from all shards of the selected split
# (is_training and data_dir are expected from the surrounding script).
dataset = tf.data.TFRecordDataset(filenames=get_filenames(is_training, data_dir))
def parse_record(raw_record):
    """Parse one serialized Example written by method 1 (encoded file bytes).

    NOTE: a second ``def parse_record`` further down in this file rebinds this
    name, so the later definition is the one later code sees. Keep only the
    parser that matches how the TFRecords were written.

    Args:
        raw_record: a serialized ``tf.train.Example``.

    Returns:
        (image, label): float32 image and int32 label, each with static
        shape [None, None, 1].
    """
    # The feature spec must mirror the keys/types used when writing.
    features = tf.parse_single_example(
        raw_record,
        {
            'image': tf.FixedLenFeature((), tf.string),
            'label': tf.FixedLenFeature((), tf.string),
        })

    def _decode(key, cast_fn):
        # Second positional arg of decode_image is channels=1 (single channel).
        decoded = tf.image.decode_image(tf.reshape(features[key], shape=[]), 1)
        decoded = cast_fn(tf.image.convert_image_dtype(decoded, dtype=tf.uint8))
        decoded.set_shape([None, None, 1])
        return decoded

    # Image ops are built first, then label ops.
    return _decode('image', tf.to_float), _decode('label', tf.to_int32)
def parse_record(raw_record, height=256, width=256, channels=1):
    """Parse one serialized Example written by method 2 (raw uint8 pixels).

    The stored bytes carry no shape information, so the size must be known in
    advance. The defaults reproduce the 256x256x1 layout this parser
    originally hard-coded. Other sizes can be used via e.g.
    ``dataset.map(lambda r: parse_record(r, 512, 512, 3))``.

    Note: this definition rebinds the earlier ``parse_record`` (method 1) in
    this file.

    Args:
        raw_record: a serialized ``tf.train.Example``.
        height: image/label height in pixels.
        width: image/label width in pixels.
        channels: number of image channels; the label is always 1 channel.

    Returns:
        (image, label): float32 image of shape [height, width, channels] and
        int32 label of shape [height, width, 1].
    """
    keys_to_features = {
        'image': tf.FixedLenFeature((), tf.string),
        'label': tf.FixedLenFeature((), tf.string),
    }
    parsed = tf.parse_single_example(raw_record, keys_to_features)

    # decode_raw reinterprets the bytes as a flat vector; the dtype must match
    # the one used when writing (ndarray.astype(np.uint8)).
    image = tf.decode_raw(parsed['image'], tf.uint8)
    image = tf.reshape(tf.cast(image, tf.float32), [height, width, channels])

    label = tf.decode_raw(parsed['label'], tf.uint8)
    label = tf.reshape(tf.cast(label, tf.int32), [height, width, 1])
    return image, label
# Decode every serialized record with parse_record. NOTE: parse_record is
# defined twice above; the later (decode_raw / method 2) definition is the
# one bound here. Keep only the parser that matches how the data was written.
dataset = dataset.map(parse_record)
# Apply preprocess_image() (defined elsewhere, not shown) to every
# (image, label) pair.
dataset = dataset.map(lambda image, label: preprocess_image(image, label, is_training))
# NOTE(review): buffer_size, num_epochs, batch_size and is_training are not
# defined in this snippet; they are assumed to come from the calling script.
dataset = dataset.shuffle(buffer_size).repeat(num_epochs).batch(batch_size)
# Prepare the next batch while the current one is being consumed.
dataset = dataset.prefetch(1)
# Each sess.run([images, labels]) yields one batch of batch_size examples.
# Pass a list: in sess.run(images, labels) the second argument is feed_dict.
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
def eval_input_fn(image_filenames, label_filenames=None, batch_size=1,
                  image_channels=3):
    """Build a one-shot input pipeline that decodes image files directly.

    Unlike the TFRecord pipeline, each file is read with tf.read_file() and
    decoded on the fly with tf.image.decode_image(); nothing goes through a
    TFRecord.

    Args:
        image_filenames: list of image file paths.
        label_filenames: optional list of label file paths, aligned
            element-wise with ``image_filenames``. When None, only images
            are produced.
        batch_size: number of examples per batch.
        image_channels: number of channels images are decoded into (passed
            as ``channels`` to tf.image.decode_image). The default of 3
            matches the static shape this function always declared. Decoding
            with an explicit channel count keeps grayscale/RGBA files
            consistent with that shape. NOTE(review): the TFRecord parsers in
            this file produce 1-channel images; pass 1 here if the model was
            trained on those.

    Returns:
        (images, labels) from a one-shot iterator: ``images`` is float32
        [batch, H, W, image_channels], ``labels`` is int32 [batch, H, W, 1],
        or None when ``label_filenames`` is None.
    """

    def _decode_image(image_filename):
        # Read and decode one image file into a float32 [H, W, C] tensor.
        image = tf.image.decode_image(tf.read_file(image_filename),
                                      channels=image_channels)
        # decode_image already yields uint8, so a plain cast is enough.
        image = tf.cast(image, tf.float32)
        image.set_shape([None, None, image_channels])
        return image

    def _decode_label(label_filename):
        # Read and decode one label file into an int32 [H, W, 1] tensor.
        label = tf.image.decode_image(tf.read_file(label_filename))
        label = tf.cast(label, tf.int32)
        label.set_shape([None, None, 1])
        return label

    if label_filenames is None:
        # Dataset of file names -> dataset of decoded images.
        dataset = tf.data.Dataset.from_tensor_slices(image_filenames)
        dataset = dataset.map(_decode_image)
    else:
        # Dataset of (image path, label path) pairs -> (image, label) pairs.
        dataset = tf.data.Dataset.from_tensor_slices(
            (image_filenames, label_filenames))
        dataset = dataset.map(
            lambda image_filename, label_filename: (
                _decode_image(image_filename), _decode_label(label_filename)))

    # Batch first, then prefetch whole batches, so building the next batch
    # overlaps with consumption. The old code prefetched single elements
    # before batching.
    dataset = dataset.batch(batch_size).prefetch(1)
    iterator = dataset.make_one_shot_iterator()

    if label_filenames is None:
        return iterator.get_next(), None
    images, labels = iterator.get_next()
    return images, labels