本文主要介绍谷歌官方在Github TensorFlow中开源的官方代码DeepLab在读取TFRecord格式数据集所使用的方法。
首先,需要将整个工程拉取到本地的workspace。
1. 源码地址:https://github.com/tensorflow/models/tree/master/research/deeplab
2. 将源代码拉取到自己的workspace中。
git clone https://github.com/tensorflow/models.git
3. 测试是否安装配置成功。
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
# From tensorflow/models/research/
python deeplab/model_test.py
读取数据集部分的代码出现在以下文件中,以PASCAL_VOC的TFRecord格式数据集进行训练过程为例:
(1)deeplab/train.py
(2)deeplab/datasets/segmentation_dataset.py
(3)deeplab/utils/input_generator.py
1. 输入指令参数
在train.py中可看到以下代码,即需要输入3个参数:train_logdir、tf_initial_checkpoint、dataset_dir。
if __name__ == '__main__':
flags.mark_flag_as_required('train_logdir')
flags.mark_flag_as_required('tf_initial_checkpoint')
flags.mark_flag_as_required('dataset_dir')
tf.app.run()
其中
train_logdir="/deeplab/datasets/pascal_voc_seg/exp/train_on_train_set/train"(训练结束后的checkpoint存放路径)
tf_initial_checkpoint="/deeplab/datasets/cityscapes/deeplabv3_mnv2_pascal_trainval/ model.ckpt-30000.index"(预训练好的checkpoint路径)
dataset_dir="/deeplab/datasets/pascal_voc_seg/tfrecord"(数据集路径)
2. 通过指令输入的参数,获得一个slim.Dataset的实例
2.1 调用segmentation_dataset.py中的get_dataset()函数。
dataset = segmentation_dataset.get_dataset(
FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
输入参数如下:
FLAGS.dataset= 'pascal_voc_seg'
FLAGS.train_split= 'train'
FLAGS.dataset_dir='/deeplab/datasets/pascal_voc_seg/tfrecord'(即在1中输入的dataset_dir参数)
2.2 在segmentation_dataset.py中的get_dataset()函数,定义如下:
def get_dataset(dataset_name, split_name, dataset_dir):
(1)首先,进行两个判断。输入的参数中,dataset_name必须是pascal_voc_seg、cityscapes、ade20k其中的一个,否则报错;接着获取数据集的基本信息,如果输入的split_name不是train、train_aug、trainval、val其中的一个,则报错。
if dataset_name not in _DATASETS_INFORMATION:
raise ValueError('The specified dataset is not supported yet.')
splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
if split_name not in splits_to_sizes:
raise ValueError('data split name %s not recognized' % split_name)
PASCAL_VOC数据集的基本信息如下:
_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 1464,
'train_aug': 10582,
'trainval': 2913,
'val': 1449,
},
num_classes=21,
ignore_label=255,
)
(2)接着获取得到num_classes = 21, ignore_label = 255。
num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label
(3)接下来,获得数据格式file_pattern,由两部分拼接而成。因为tfrecord格式的命名为train-*,所以file_pattern=/deeplab/datasets/pascal_voc_seg/tfrecord/train-*。
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
(4)声明TF-Examples的解码方式。
keys_to_features = {
'image/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/filename': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/format': tf.FixedLenFeature(
(), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/segmentation/class/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/segmentation/class/format': tf.FixedLenFeature(
(), tf.string, default_value='png'),
}
items_to_handlers = {
'image': tfexample_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3),
'image_name': tfexample_decoder.Tensor('image/filename'),
'height': tfexample_decoder.Tensor('image/height'),
'width': tfexample_decoder.Tensor('image/width'),
'labels_class': tfexample_decoder.Image(
image_key='image/segmentation/class/encoded',
format_key='image/segmentation/class/format',
channels=1),
}
(5)将声明好的两个dict输入TFExampleDecoder。
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
(6)最后,返回一个slim.Dataset的实例。
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=splits_to_sizes[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
ignore_label=ignore_label,
num_classes=num_classes,
name=dataset_name,
multi_label=True)
其中具体的参数值如下:
file_pattern: /deeplab/datasets/pascal_voc_seg/tfrecord/train-*
tf.TFRecordReader: tf.TFRecordReader(读取方式)
decoder: decoder
splits_to_sizes[split_name]: 1464(samples的个数)
_ITEMS_TO_DESCRIPTIONS: 一个dict,包含一些描述,如下:
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying height and width.',
'labels_class': ('A semantic segmentation label whose size matches image.'
'Its values range from 0 (background) to num_classes.'),
}
ignore_label: 255
num_classes: 21(包含一个背景)
dataset_name: pascal_voc_seg
multi_label: True
该slim.dataset的实例,也即2.1中的train.py通过调用该函数得到的dataset。
3. 获得由tf.train.batch()生成的实例samples
3.1 在train.py中调用input_generator.py中的get()函数。
samples = input_generator.get(
dataset,
FLAGS.train_crop_size,
clone_batch_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size,
dataset_split=FLAGS.train_split,
is_training=True,
model_variant=FLAGS.model_variant)
输入参数如下:
dataset: 上一步获得的slim.Dataset实例dataset
FLAGS.train_crop_size: [513, 513]
clone_batch_size: 8,计算代码如下,train_batch_size=8, num_clones=1
clone_batch_size = FLAGS.train_batch_size // config.num_clones
FLAGS.min_resize_value: 未找到
FLAGS.max_resize_value: 未找到
FLAGS.resize_factor: 未找到
FLAGS.min_scale_factor: 0.5
FLAGS.max_scale_factor: 2.
FLAGS.scale_factor_step_size: 0.25
FLAGS.train_split: train
is_training: True
FLAGS.model_variant: xception_65
3.2 在input_generator.py中的get()函数,其定义如下:
def get(dataset,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None):
(1)首先,两个判断,保证明确dataset的正确划分和model_variant的声明。
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
(2)生成一个slim.dataset_data_provider的实例,dataset为之前获得的slim.Dataset实例,num_readers = 1,num_epochs = None,shuffle = True。
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
(3)调用_get_data()函数。
image, label, image_name, height, width = _get_data(data_provider,
dataset_split)
def _get_data(data_provider, dataset_split):
_get_data()函数,通过slim.dataset_data_provider的get方法获取到image、height、width,接着获取到data_name。接下来,判断是否为训练/验证过程,若是训练/验证过程,则获取到label,否则label为None。最后,返回image、label、image_name、height、width这5个tensor给get()函数。
if common.LABELS_CLASS not in data_provider.list_items():
raise ValueError('Failed to find labels.')
image, height, width = data_provider.get(
[common.IMAGE, common.HEIGHT, common.WIDTH])
# Some datasets do not contain image_name.
if common.IMAGE_NAME in data_provider.list_items():
image_name, = data_provider.get([common.IMAGE_NAME])
else:
image_name = tf.constant('')
label = None
if dataset_split != common.TEST_SET:
label, = data_provider.get([common.LABELS_CLASS])
return image, label, image_name, height, width
(4)接着,继续在get()函数中。判断通过_get_data()函数返回的label是否为None,若不是None,则判断维度是否为[, , 1]。若是2维,则扩维;若是3维且第三维是否为1,则跳过,否则报错;最后将label维度设置为[None,None,1]。
if label is not None:
if label.shape.ndims == 2:
label = tf.expand_dims(label, 2)
elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].')
label.set_shape([None, None, 1])
(5)接着,调用input_process.py中的preprocess_image_and_label()函数,对训练过程中用到的image和label进行操作,比如resize和归一化等。
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset["ignore_label"],
is_training=is_training,
model_variant=model_variant)
其中,输入参数如下:
image: image
label: label
crop_height: 513
crop_width: 513
min_resize_value: None
max_resize_value: None
resize_factor: None
min_scale_factor: 0.5
max_scale_factor: 2.
scale_factor_step_size: 0.25
ignore_label: 255
is_training: True
model_variant: xception_65
返回值如下:
original_image: resize后的image
image: resize后的经过预处理的image
label: resize后的经过预处理的label
(6)接着,声明一个dict实例sample。分别将image, image_name, height, width传入。若label不是None,则将label也传入。若非训练过程,则将original_image也传入用于visualization过程。
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: height,
common.WIDTH: width
}
if label is not None:
sample[common.LABEL] = label
if not is_training:
sample[common.ORIGINAL_IMAGE] = original_image,
num_threads = 1
(7)最后,调用tf.train.batch(),返回一个samples。
return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
4. 最后,在train.py中,调用slim.prefetch_queue.prefetch_queue()方法,生成输入队列
inputs_queue = prefetch_queue.prefetch_queue(
samples, capacity=128 * config.num_clones)
至此,读取数据集的过程结束。
1. 本文为作者原创,如需转载,请注明本文链接和作者ID:superkoma。
2. 创作本文目的是为了理解DeepLab v3读取数据集的主要流程,方便在自己的数据集上进行训练和验证测试等。一般大家会将自己的数据集转为TFRecord格式以满足输入要求,但是另一种思路是修改其源码读取数据集的部分,使其能够直接从一个包含图像路径的list中直接读取。作者后续会推出修改方式。