Semantic Segmentation DeepLab v3 读取数据集(TFRecord)代码详解

本文主要介绍谷歌官方在Github TensorFlow中开源的官方代码DeepLab在读取TFRecord格式数据集所使用的方法。

配置DeepLab v3

首先,需要将整个工程拉取到本地的workspace。

1. 源码地址:https://github.com/tensorflow/models/tree/master/research/deeplab

2. 将源代码拉取到自己的workspace中。

git clone https://github.com/tensorflow/models.git

3. 测试是否安装配置成功。

# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
# From tensorflow/models/research/
python deeplab/model_test.py

读取数据集代码分析

读取数据集部分的代码出现在以下文件中,以PASCAL_VOC的TFRecord格式数据集进行训练过程为例:

(1)deeplab/train.py

(2)deeplab/datasets/segmentation_dataset.py

(3)deeplab/utils/input_generator.py

1. 输入指令参数

在train.py中可看到以下代码,即需要输入3个参数:train_logdir、tf_initial_checkpoint、dataset_dir。

if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  flags.mark_flag_as_required('tf_initial_checkpoint')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()

其中

train_logdir="/deeplab/datasets/pascal_voc_seg/exp/train_on_train_set/train"(训练结束后的checkpoint存放路径)

tf_initial_checkpoint="/deeplab/datasets/cityscapes/deeplabv3_mnv2_pascal_trainval/ model.ckpt-30000.index"(预训练好的checkpoint路径)

dataset_dir="/deeplab/datasets/pascal_voc_seg/tfrecord"(数据集路径)

2. 通过指令输入的参数,获得一个slim.Dataset的实例

2.1 调用segmentation_dataset.py中的get_dataset()函数。

dataset = segmentation_dataset.get_dataset(
     FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

输入参数如下:

FLAGS.dataset= 'pascal_voc_seg'

FLAGS.train_split= 'train'

FLAGS.dataset_dir='/deeplab/datasets/pascal_voc_seg/tfrecord'(即在1中输入的dataset_dir参数)

2.2 在segmentation_dataset.py中的get_dataset()函数,定义如下:

def get_dataset(dataset_name, split_name, dataset_dir):

(1)首先,进行两个判断。输入的参数中,dataset_name必须是pascal_voc_seg、cityscapes、ade20k其中的一个,否则报错;接着获取数据集的基本信息,如果输入的split_name不是train、train_aug、trainval、val其中的一个,则报错。

if dataset_name not in _DATASETS_INFORMATION:
    raise ValueError('The specified dataset is not supported yet.')

  splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes

  if split_name not in splits_to_sizes:
    raise ValueError('data split name %s not recognized' % split_name)

PASCAL_VOC数据集的基本信息如下:

_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1464,
        'train_aug': 10582,
        'trainval': 2913,
        'val': 1449,
    },
    num_classes=21,
    ignore_label=255,
)

(2)接着获取得到num_classes = 21, ignore_label = 255。

num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
  ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label

(3)接下来,获得数据格式file_pattern,由两部分拼接而成。因为tfrecord格式的命名为train-*,所以file_pattern=/deeplab/datasets/pascal_voc_seg/tfrecord/train-*。

file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

(4)声明TF-Examples的解码方式。

keys_to_features = {
      'image/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/filename': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature(
          (), tf.string, default_value='jpeg'),
      'image/height': tf.FixedLenFeature(
          (), tf.int64, default_value=0),
      'image/width': tf.FixedLenFeature(
          (), tf.int64, default_value=0),
      'image/segmentation/class/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/segmentation/class/format': tf.FixedLenFeature(
          (), tf.string, default_value='png'),
  }
  items_to_handlers = {
      'image': tfexample_decoder.Image(
          image_key='image/encoded',
          format_key='image/format',
          channels=3),
      'image_name': tfexample_decoder.Tensor('image/filename'),
      'height': tfexample_decoder.Tensor('image/height'),
      'width': tfexample_decoder.Tensor('image/width'),
      'labels_class': tfexample_decoder.Image(
          image_key='image/segmentation/class/encoded',
          format_key='image/segmentation/class/format',
          channels=1),
  }

(5)将声明好的两个dict输入TFExampleDecoder。

decoder = tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

(6)最后,返回一个slim.Dataset的实例。

return dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=splits_to_sizes[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      ignore_label=ignore_label,
      num_classes=num_classes,
      name=dataset_name,
      multi_label=True)

其中具体的参数值如下:

file_pattern: /deeplab/datasets/pascal_voc_seg/tfrecord/train-*

tf.TFRecordReader: tf.TFRecordReader(读取方式)

decoder: decoder

splits_to_sizes[split_name]: 1464(samples的个数)

_ITEMS_TO_DESCRIPTIONS: 一个dict,包含一些描述,如下:

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'labels_class': ('A semantic segmentation label whose size matches image.'
                     'Its values range from 0 (background) to num_classes.'),
}

ignore_label: 255

num_classes: 21(包含一个背景)

dataset_name: pascal_voc_seg

multi_label: True

 

该slim.dataset的实例,也即2.1中的train.py通过调用该函数得到的dataset。

 

3. 获得由tf.train.batch()生成的实例samples

3.1 在train.py中调用input_generator.py中的get()函数。

samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)

输入参数如下:

dataset: 上一步获得的slim.Dataset实例dataset

FLAGS.train_crop_size: [513, 513]

clone_batch_size: 8,计算代码如下,train_batch_size=8, num_clones=1

clone_batch_size = FLAGS.train_batch_size // config.num_clones

FLAGS.min_resize_value: 未找到

FLAGS.max_resize_value: 未找到

FLAGS.resize_factor: 未找到

FLAGS.min_scale_factor: 0.5

FLAGS.max_scale_factor: 2.

FLAGS.scale_factor_step_size: 0.25

FLAGS.train_split: train

is_training: True

FLAGS.model_variant: xception_65

3.2 在input_generator.py中的get()函数,其定义如下:

def get(dataset,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None):

(1)首先,两个判断,保证明确dataset的正确划分和model_variant的声明。

if dataset_split is None:
    raise ValueError('Unknown dataset split.')
  if model_variant is None:
    tf.logging.warning('Please specify a model_variant. See '
                       'feature_extractor.network_map for supported model '
                       'variants.')

(2)生成一个slim.dataset_data_provider的实例,dataset为之前获得的slim.Dataset实例,num_readers = 1,num_epochs = None,shuffle = True。

data_provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      num_epochs=None if is_training else 1,
      shuffle=is_training)

(3)调用_get_data()函数。

image, label, image_name, height, width = _get_data(data_provider,
                                                      dataset_split)
def _get_data(data_provider, dataset_split):

_get_data()函数,通过slim.dataset_data_provider的get方法获取到image、height、width,接着获取到data_name。接下来,判断是否为训练/验证过程,若是训练/验证过程,则获取到label,否则label为None。最后,返回image、label、image_name、height、width这5个tensor给get()函数。

if common.LABELS_CLASS not in data_provider.list_items():
    raise ValueError('Failed to find labels.')

  image, height, width = data_provider.get(
      [common.IMAGE, common.HEIGHT, common.WIDTH])

  # Some datasets do not contain image_name.
  if common.IMAGE_NAME in data_provider.list_items():
    image_name, = data_provider.get([common.IMAGE_NAME])
  else:
    image_name = tf.constant('')

  label = None
  if dataset_split != common.TEST_SET:
    label, = data_provider.get([common.LABELS_CLASS])

  return image, label, image_name, height, width

(4)接着,继续在get()函数中。判断通过_get_data()函数返回的label是否为None,若不是None,则判断维度是否为[, , 1]。若是2维,则扩维;若是3维且第三维是否为1,则跳过,否则报错;最后将label维度设置为[None,None,1]。

if label is not None:
    if label.shape.ndims == 2:
      label = tf.expand_dims(label, 2)
    elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
      pass
    else:
      raise ValueError('Input label shape must be [height, width], or '
                       '[height, width, 1].')

    label.set_shape([None, None, 1])

(5)接着,调用input_process.py中的preprocess_image_and_label()函数,对训练过程中用到的image和label进行操作,比如resize和归一化等。

original_image, image, label = input_preprocess.preprocess_image_and_label(
      image,
      label,
      crop_height=crop_size[0],
      crop_width=crop_size[1],
      min_resize_value=min_resize_value,
      max_resize_value=max_resize_value,
      resize_factor=resize_factor,
      min_scale_factor=min_scale_factor,
      max_scale_factor=max_scale_factor,
      scale_factor_step_size=scale_factor_step_size,
      ignore_label=dataset["ignore_label"],
      is_training=is_training,
      model_variant=model_variant)

其中,输入参数如下:

image: image

label: label

crop_height: 513

crop_width: 513

min_resize_value: None

max_resize_value: None

resize_factor: None

min_scale_factor: 0.5

max_scale_factor: 2.

scale_factor_step_size: 0.25

ignore_label: 255

is_training: True

model_variant: xception_65

返回值如下:

original_image: resize后的image

image: resize后的经过预处理的image

label: resize后的经过预处理的label

(6)接着,声明一个dict实例sample。分别将image, image_name, height, width传入。若label不是None,则将label也传入。若非训练过程,则将original_image也传入用于visualization过程。

sample = {
      common.IMAGE: image,
      common.IMAGE_NAME: image_name,
      common.HEIGHT: height,
      common.WIDTH: width
  }
  if label is not None:
    sample[common.LABEL] = label

  if not is_training:
    sample[common.ORIGINAL_IMAGE] = original_image,
    num_threads = 1

(7)最后,调用tf.train.batch(),返回一个samples。

  return tf.train.batch(
      sample,
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=32 * batch_size,
      allow_smaller_final_batch=not is_training,
      dynamic_pad=True)

4. 最后,在train.py中,调用slim.prefetch_queue.prefetch_queue()方法,生成输入队列

inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

至此,读取数据集的过程结束。

 

 

声明

1. 本文为作者原创,如需转载,请注明本文链接和作者ID:superkoma。

2. 创作本文目的是为了理解DeepLab v3读取数据集的主要流程,方便在自己的数据集上进行训练和验证测试等。一般大家会将自己的数据集转为TFRecord格式以满足输入要求,但是另一种思路是修改其源码读取数据集的部分,使其能够直接从一个包含图像路径的list中直接读取。作者后续会推出修改方式。

你可能感兴趣的:(深度学习)