DeepLabv3+ (Part 2): A detailed code walkthrough of the data generator

DeepLabv3+ supports three datasets: PASCAL VOC 2012, Cityscapes, and ADE20K.

Code folder

-deeplab

    -datasets

         -data_generator.py

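Before starting, always keep in mind that the input to the network is very simple: an image normalized to [-1, 1] or [0, 1], possibly with data augmentation (horizontal flip, random crop, brightness change, blur), plus a label that has gone through the same geometric augmentation (the pixels have to line up, after all). For testing, only an image is needed. This is a very simple data format. Programmers may write complicated code for the sake of storage compression and read/processing speed (which is what tf.Example and TFRecord are for), but the end result is always simple.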

I felt I had to get tf.Example and TFRecord straight first: https://zhuanlan.zhihu.com/p/33223782

 

Contents

Dataset analysis

The key class: Dataset

1. Method _parse_function()

2. Method _preprocess_image()

2.1 The preprocess_image_and_label function in input_preprocess

3. Method _get_all_files(self)

4. Method get_one_shot_iterator(self)

Class TFRecordDataset

Usage in train.py


Code: here is the code first. Try reading it yourself; if you understand it, there is no need to scroll further.

# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data.

The SegmentationDataset class provides both images and annotations (semantic
segmentation and/or instance segmentation) for TensorFlow. Currently, we
support the following datasets:

1. PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).

PASCAL VOC 2012 semantic segmentation dataset annotates 20 foreground objects
(e.g., bike, person, and so on) and leaves all the other semantic classes as
one background class. The dataset contains 1464, 1449, and 1456 annotated
images for the training, validation and test respectively.

2. Cityscapes dataset (https://www.cityscapes-dataset.com)

The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
and so on) for urban street scenes.

3. ADE20K dataset (http://groups.csail.mit.edu/vision/datasets/ADE20K)

The ADE20K dataset contains 150 semantic labels both urban street scenes and
indoor scenes.

References:
  M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
  and A. Zisserman, The pascal visual object classes challenge a retrospective.
  IJCV, 2014.

  M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
  U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
  scene understanding," In Proc. of CVPR, 2016.

  B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso, A. Torralba, "Scene Parsing
  through ADE20K dataset", In Proc. of CVPR, 2017.
"""

import collections
import os
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess

# Named tuple to describe the dataset properties.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor',
    [
        'splits_to_sizes',  # Splits of the dataset into training, val and test.
        'num_classes',  # Number of semantic classes, including the
                        # background class (if exists). For example, there
                        # are 20 foreground classes + 1 background class in
                        # the PASCAL VOC 2012 dataset. Thus, we set
                        # num_classes=21.
        'ignore_label',  # Ignore label value.
    ])

_CITYSCAPES_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 2975,
        'val': 500,
    },
    num_classes=19,
    ignore_label=255,
)

_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1464,
        'train_aug': 10582,
        'trainval': 2913,
        'val': 1449,
    },
    num_classes=21,
    ignore_label=255,
)

_ADE20K_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 20210,  # num of samples in images/training
        'val': 2000,  # num of samples in images/validation
    },
    num_classes=151,
    ignore_label=0,
)

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
}

# Default file pattern of TFRecord of TensorFlow Example.
_FILE_PATTERN = '%s-*'


def get_cityscapes_dataset_name():
  return 'cityscapes'


class Dataset(object):
  """Represents input dataset for deeplab model."""

  def __init__(self,
               dataset_name,
               split_name,
               dataset_dir,
               batch_size,
               crop_size,
               min_resize_value=None,
               max_resize_value=None,
               resize_factor=None,
               min_scale_factor=1.,
               max_scale_factor=1.,
               scale_factor_step_size=0,
               model_variant=None,
               num_readers=1,
               is_training=False,
               should_shuffle=False,
               should_repeat=False):
    """Initializes the dataset.

    Args:
      dataset_name: Dataset name.
      split_name: A train/val Split name.
      dataset_dir: The directory of the dataset sources.
      batch_size: Batch size.
      crop_size: The size used to crop the image and label.
      min_resize_value: Desired size of the smaller image side.
      max_resize_value: Maximum allowed size of the larger image side.
      resize_factor: Resized dimensions are multiple of factor plus one.
      min_scale_factor: Minimum scale factor value.
      max_scale_factor: Maximum scale factor value.
      scale_factor_step_size: The step size from min scale factor to max scale
        factor. The input is randomly scaled based on the value of
        (min_scale_factor, max_scale_factor, scale_factor_step_size).
      model_variant: Model variant (string) for choosing how to mean-subtract
        the images. See feature_extractor.network_map for supported model
        variants.
      num_readers: Number of readers for data provider.
      is_training: Boolean, if dataset is for training or not.
      should_shuffle: Boolean, if should shuffle the input data.
      should_repeat: Boolean, if should repeat the input data.

    Raises:
      ValueError: Dataset name and split name are not supported.
    """
    if dataset_name not in _DATASETS_INFORMATION:
      raise ValueError('The specified dataset is not supported yet.')
    self.dataset_name = dataset_name

    splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes

    if split_name not in splits_to_sizes:
      raise ValueError('data split name %s not recognized' % split_name)

    if model_variant is None:
      tf.logging.warning('Please specify a model_variant. See '
                         'feature_extractor.network_map for supported model '
                         'variants.')

    self.split_name = split_name
    self.dataset_dir = dataset_dir
    self.batch_size = batch_size
    self.crop_size = crop_size
    self.min_resize_value = min_resize_value
    self.max_resize_value = max_resize_value
    self.resize_factor = resize_factor
    self.min_scale_factor = min_scale_factor
    self.max_scale_factor = max_scale_factor
    self.scale_factor_step_size = scale_factor_step_size
    self.model_variant = model_variant
    self.num_readers = num_readers
    self.is_training = is_training
    self.should_shuffle = should_shuffle
    self.should_repeat = should_repeat

    self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes
    self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label

  def _parse_function(self, example_proto):
    """Function to parse the example proto.

    Args:
      example_proto: Proto in the format of tf.Example.

    Returns:
      A dictionary with parsed image, label, height, width and image name.

    Raises:
      ValueError: Label is of wrong shape.
    """

    # Currently only supports jpeg and png.
    # Need to use this logic because the shape is not known for
    # tf.image.decode_image and we rely on this info to
    # extend label if necessary.
    def _decode_image(content, channels):
      return tf.cond(
          tf.image.is_jpeg(content),
          lambda: tf.image.decode_jpeg(content, channels),
          lambda: tf.image.decode_png(content, channels))

    features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/segmentation/class/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/format':
            tf.FixedLenFeature((), tf.string, default_value='png'),
    }

    parsed_features = tf.parse_single_example(example_proto, features)

    image = _decode_image(parsed_features['image/encoded'], channels=3)

    label = None
    if self.split_name != common.TEST_SET:
      label = _decode_image(
          parsed_features['image/segmentation/class/encoded'], channels=1)

    image_name = parsed_features['image/filename']
    if image_name is None:
      image_name = tf.constant('')

    sample = {
        common.IMAGE: image,
        common.IMAGE_NAME: image_name,
        common.HEIGHT: parsed_features['image/height'],
        common.WIDTH: parsed_features['image/width'],
    }

    if label is not None:
      if label.get_shape().ndims == 2:
        label = tf.expand_dims(label, 2)
      elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
        pass
      else:
        raise ValueError('Input label shape must be [height, width], or '
                         '[height, width, 1].')

      label.set_shape([None, None, 1])

      sample[common.LABELS_CLASS] = label

    return sample

  def _preprocess_image(self, sample):
    """Preprocesses the image and label.

    Args:
      sample: A sample containing image and label.

    Returns:
      sample: Sample with preprocessed image and label.

    Raises:
      ValueError: Ground truth label not provided during training.
    """
    image = sample[common.IMAGE]
    label = sample[common.LABELS_CLASS]

    original_image, image, label = input_preprocess.preprocess_image_and_label(
        image=image,
        label=label,
        crop_height=self.crop_size[0],
        crop_width=self.crop_size[1],
        min_resize_value=self.min_resize_value,
        max_resize_value=self.max_resize_value,
        resize_factor=self.resize_factor,
        min_scale_factor=self.min_scale_factor,
        max_scale_factor=self.max_scale_factor,
        scale_factor_step_size=self.scale_factor_step_size,
        ignore_label=self.ignore_label,
        is_training=self.is_training,
        model_variant=self.model_variant)

    sample[common.IMAGE] = image

    if not self.is_training:
      # Original image is only used during visualization.
      sample[common.ORIGINAL_IMAGE] = original_image

    if label is not None:
      sample[common.LABEL] = label

    # Remove common.LABEL_CLASS key in the sample since it is only used to
    # derive label and not used in training and evaluation.
    sample.pop(common.LABELS_CLASS, None)

    return sample

  def get_one_shot_iterator(self):
    """Gets an iterator that iterates across the dataset once.

    Returns:
      An iterator of type tf.data.Iterator.
    """

    files = self._get_all_files()

    dataset = (
        tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
        .map(self._parse_function, num_parallel_calls=self.num_readers)
        .map(self._preprocess_image, num_parallel_calls=self.num_readers))

    if self.should_shuffle:
      dataset = dataset.shuffle(buffer_size=100)

    if self.should_repeat:
      dataset = dataset.repeat()  # Repeat forever for training.
    else:
      dataset = dataset.repeat(1)

    dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
    return dataset.make_one_shot_iterator()

  def _get_all_files(self):
    """Gets all the files to read data from.

    Returns:
      A list of input files.
    """
    file_pattern = _FILE_PATTERN
    file_pattern = os.path.join(self.dataset_dir,
                                file_pattern % self.split_name)
    return tf.gfile.Glob(file_pattern)

Dataset analysis

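The code declares a namedtuple, which creates a new data type. Its splits_to_sizes field maps each split of the dataset (train, train_aug, trainval, val, and so on) to the number of images in that split. This tells the model how many images are available for the split it is running on, whether that is training or evaluation.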

DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor',
    [
        'splits_to_sizes',  # Splits of the dataset into training, val and test.
        'num_classes',  # Number of semantic classes, including the
                        # background class (if exists). For example, there
                        # are 20 foreground classes + 1 background class in
                        # the PASCAL VOC 2012 dataset. Thus, we set
                        # num_classes=21.
        'ignore_label',  # Ignore label value.
    ])

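For example, the VOC dataset uses DatasetDescriptor to declare its properties: 1464 training images, 1449 validation images, and 21 classes in total, including the background class. The value 255 marks the white border around objects, i.e. the unlabeled region, so it is excluded from the loss.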

_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1464,
        'train_aug': 10582,
        'trainval': 2913,
        'val': 1449,
    },
    num_classes=21,
    ignore_label=255,
)

Three datasets in the code have DatasetDescriptor information, collected in this dictionary:

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
}

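If the dataset name you pass in is not one of the registered ones, a ValueError is raised, so your own dataset needs an entry built in the same way in this generator.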
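The lookup-and-fail behaviour can be sketched without TensorFlow. In this minimal sketch, `my_dataset`, its split sizes and class count, and the helper `get_dataset_info` are hypothetical names introduced for illustration; only the `DatasetDescriptor` fields and the two error messages come from the code above.

```python
import collections

# Same three fields as the DatasetDescriptor in data_generator.py.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor', ['splits_to_sizes', 'num_classes', 'ignore_label'])

# Hypothetical custom dataset; the sizes and class count are made up.
_MY_DATASET_INFORMATION = DatasetDescriptor(
    splits_to_sizes={'train': 800, 'val': 200},
    num_classes=3,
    ignore_label=255,
)

_DATASETS_INFORMATION = {
    'my_dataset': _MY_DATASET_INFORMATION,
}


def get_dataset_info(dataset_name, split_name):
  """Mirrors the two checks at the top of Dataset.__init__ (sketch only)."""
  if dataset_name not in _DATASETS_INFORMATION:
    raise ValueError('The specified dataset is not supported yet.')
  info = _DATASETS_INFORMATION[dataset_name]
  if split_name not in info.splits_to_sizes:
    raise ValueError('data split name %s not recognized' % split_name)
  return info
```

Calling `get_dataset_info('my_dataset', 'train')` returns the descriptor, while an unregistered name raises ValueError, just as the constructor does.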

The key class: Dataset

The most important thing in this file is the Dataset class.

 

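Initializing this class takes many arguments, and the docstring describes them in detail. model_variant should always be passed: if it is None, the constructor does not raise an error, it only logs a warning through tf.logging.warning.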

The docstring explains model_variant as follows:

model_variant: Model variant (string) for choosing how to mean-subtract
  the images. See feature_extractor.network_map for supported model
  variants.

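Sometimes preprocessing needs to subtract a mean value. For an image that value is 255/2: subtracting 255/2 and then dividing by 255/2 maps the pixel values to [-1, 1].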

Without the subtraction you would simply get [0, 1] normalization. feature_extractor lives in the core folder.
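The arithmetic behind the two ranges can be checked with a few lines. This is only a sketch of the formulas mentioned above; the function names are hypothetical, and which formula a given model_variant actually uses is decided in feature_extractor, not here.

```python
def to_unit_range(pixel):
  """Maps a pixel value in [0, 255] to [0, 1]."""
  return pixel / 255.0


def to_zero_centered_range(pixel):
  """Maps a pixel value in [0, 255] to [-1, 1].

  Subtracts the midpoint 255/2 and divides by the same value.
  """
  half = 255.0 / 2.0
  return (pixel - half) / half
```

The endpoints 0 and 255 land on -1 and 1 in the zero-centered version, and on 0 and 1 in the unit-range version.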

 

Two more attributes are initialized from the dataset information in _DATASETS_INFORMATION:

    self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes
    self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label

Now for the methods:

  • _parse_function(self,example_proto)
  • _preprocess_image(self, sample)
  • get_one_shot_iterator(self)
  • _get_all_files(self)

1. Method _parse_function()

_parse_function parses a tf.Example, a TensorFlow-specific data format, into a dictionary. It returns all the information about one image.

"""Function to parse the example proto.

Args:
  example_proto: Proto in the format of tf.Example.

Returns:
  A dictionary with parsed image, label, height, width and image name.

Raises:
  ValueError: Label is of wrong shape.
"""

The function also contains a nested function:

    def _decode_image(content, channels):
      return tf.cond(
          tf.image.is_jpeg(content),
          lambda: tf.image.decode_jpeg(content, channels),
          lambda: tf.image.decode_png(content, channels))

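It decodes the encoded image into a uint8 tensor. So what the network finally receives is still the familiar numeric image data (converted to float during preprocessing). TensorFlow converts the data into a different structure on disk, for reasons whose benefit I have not yet seen (without the conversion you could just yield the arrays as usual), and the generator converts it back at the end.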

features supplies the keys used to turn a tf.Example (which can be thought of as dictionary-like storage with keys) into an ordinary Python dictionary:

features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/segmentation/class/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/format':
            tf.FixedLenFeature((), tf.string, default_value='png'),
    }

The argument example_proto is the image data, stored in Example format, that is to be converted.

It is parsed with:

parsed_features = tf.parse_single_example(example_proto, features)

tf.parse_single_example parses a tf.Example according to the features spec. That leaves one question: what does the overall packing and unpacking pipeline look like?

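Take image segmentation: the features include the image and its label, which together form the network input.

So what happens to a uint8 image?

Packing: uint8 -> encoded_image -> tf.Example -> stored in a TFRecord file

Unpacking: read the TFRecord file -> tf.Example -> encoded_image -> uint8

So the output of parse_single_example still holds the image in encoded form, and it has to be converted with the _decode_image function mentioned above.

That is how the image is handled.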
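The pack/unpack round trip can be illustrated with the standard library alone. This is a loose analogy, not the real formats: `zlib` stands in for PNG/JPEG encoding, JSON stands in for the tf.Example protobuf, and a plain 8-byte length prefix stands in for TFRecord framing (the real format also stores checksums). The function names `pack_record` and `unpack_record` are hypothetical.

```python
import base64
import io
import json
import struct
import zlib


def pack_record(pixels, height, width, filename):
  """uint8 pixels -> encoded image -> example dict -> framed record bytes."""
  encoded_image = zlib.compress(bytes(pixels))  # stand-in for PNG/JPEG
  example = {
      'image/encoded': base64.b64encode(encoded_image).decode('ascii'),
      'image/filename': filename,
      'image/height': height,
      'image/width': width,
  }
  payload = json.dumps(example).encode('utf-8')  # stand-in for protobuf
  return struct.pack('<Q', len(payload)) + payload  # simplified framing


def unpack_record(stream):
  """Framed record bytes -> example dict -> encoded image -> uint8 pixels."""
  (length,) = struct.unpack('<Q', stream.read(8))
  example = json.loads(stream.read(length).decode('utf-8'))
  encoded_image = base64.b64decode(example['image/encoded'])
  pixels = list(zlib.decompress(encoded_image))  # the "decode" step
  return example, pixels


record = pack_record([0, 128, 255, 7], height=2, width=2, filename='demo.png')
example, pixels = unpack_record(io.BytesIO(record))
```

The point of the sketch is only the symmetry: every step of packing has a matching step in unpacking, and the pixel values that come out are the ones that went in.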

As for the label: test data has no label, so label = None.

If it is not test data, the label is decoded the same way as the image. In case you noticed common, here is the definition of common.TEST_SET:

# Test set name.
TEST_SET = 'test'

That should make it clear.

    if self.split_name != common.TEST_SET:
      label = _decode_image(
          parsed_features['image/segmentation/class/encoded'], channels=1)

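One thing is a little odd: why not call tf.image.decode_image directly, instead of checking the format first and then choosing a decoder? The comment above _decode_image may be the explanation: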

# Currently only supports jpeg and png.
# Need to use this logic because the shape is not known for
# tf.image.decode_image and we rely on this info to
# extend label if necessary.

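It says the image shape is not known with tf.image.decode_image. A likely reason (my reading, not something stated in the code): decode_image also accepts GIFs, whose decoded result is 4-D, so its output has no static rank, whereas decode_jpeg and decode_png always return 3-D [height, width, channels] tensors. With the conditional, the rank is known, which the label shape check below relies on.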

Besides, features clearly already contains the height and width:

'image/height':
    tf.FixedLenFeature((), tf.int64, default_value=0),
'image/width':
    tf.FixedLenFeature((), tf.int64, default_value=0),

One important point here is the shape of the label.

The label must end up three-dimensional, for example [512, 512, 1]. In other words:

'Input label shape must be [height, width], or '
                 '[height, width, 1].'
    if label is not None:
      if label.get_shape().ndims == 2:
        label = tf.expand_dims(label, 2)
      elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
        pass
      else:
        raise ValueError('Input label shape must be [height, width], or '
                         '[height, width, 1].')
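The same rule can be written on plain shape tuples, which makes the three cases easy to try out. This is a sketch only; `normalize_label_shape` is a hypothetical helper that works on shapes, not on tensors.

```python
def normalize_label_shape(shape):
  """Returns a (height, width, 1) shape or raises ValueError."""
  if len(shape) == 2:
    # Same effect on the shape as tf.expand_dims(label, 2).
    return (shape[0], shape[1], 1)
  if len(shape) == 3 and shape[2] == 1:
    # Already [height, width, 1]; nothing to do.
    return tuple(shape)
  raise ValueError('Input label shape must be [height, width], or '
                   '[height, width, 1].')
```

A (512, 512) label becomes (512, 512, 1), a (512, 512, 1) label passes through unchanged, and a three-channel label such as (512, 512, 3) is rejected.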

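The sample that is finally returned is a dictionary. The strings behind the common constants are listed right below. There is nothing magic about common: it is just a set of constants, for example IMAGE = 'image'.

Writing it this way keeps the names consistent across the code and lets you use Tab completion.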

    sample = {
        common.IMAGE: image,
        common.IMAGE_NAME: image_name,
        common.HEIGHT: parsed_features['image/height'],
        common.WIDTH: parsed_features['image/width'],
    }
      sample[common.LABELS_CLASS] = label
# From common.py:
# Semantic segmentation item names.
LABELS_CLASS = 'labels_class'
IMAGE = 'image'
HEIGHT = 'height'
WIDTH = 'width'
IMAGE_NAME = 'image_name'
LABEL = 'label'
ORIGINAL_IMAGE = 'original_image'

2. Method _preprocess_image()

"""Preprocesses the image and label.

Args:
  sample: A sample containing image and label.

Returns:
  sample: Sample with preprocessed image and label.

Raises:
  ValueError: Ground truth label not provided during training.
"""

This method preprocesses the image and the label.

sample is the dictionary that _parse_function just returned.

First, the label and image to be preprocessed are taken out of sample:

    image = sample[common.IMAGE]
    label = sample[common.LABELS_CLASS]

Then a single function is called to process the image and label:

    original_image, image, label = input_preprocess.preprocess_image_and_label(
        image=image,
        label=label,
        crop_height=self.crop_size[0],
        crop_width=self.crop_size[1],
        min_resize_value=self.min_resize_value,
        max_resize_value=self.max_resize_value,
        resize_factor=self.resize_factor,
        min_scale_factor=self.min_scale_factor,
        max_scale_factor=self.max_scale_factor,
        scale_factor_step_size=self.scale_factor_step_size,
        ignore_label=self.ignore_label,
        is_training=self.is_training,
        model_variant=self.model_variant)

2.1 The preprocess_image_and_label function in input_preprocess

"""Preprocesses the image and label.

Args:
  image: Input image.
  label: Ground truth annotation label.
  crop_height: The height value used to crop the image and label.
  crop_width: The width value used to crop the image and label.
  min_resize_value: Desired size of the smaller image side.
  max_resize_value: Maximum allowed size of the larger image side.
  resize_factor: Resized dimensions are multiple of factor plus one.
  min_scale_factor: Minimum scale factor value.
  max_scale_factor: Maximum scale factor value.
  scale_factor_step_size: The step size from min scale factor to max scale
    factor. The input is randomly scaled based on the value of
    (min_scale_factor, max_scale_factor, scale_factor_step_size).
  ignore_label: The label value which will be ignored for training and
    evaluation.
  is_training: If the preprocessing is used for training or not.
  model_variant: Model variant (string) for choosing how to mean-subtract the
    images. See feature_extractor.network_map for supported model variants.

Returns:
  original_image: Original image (could be resized).
  processed_image: Preprocessed image.
  label: Preprocessed ground truth segmentation label.

Raises:
  ValueError: Ground truth label not provided during training.
"""
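The docstring says the input is randomly scaled based on (min_scale_factor, max_scale_factor, scale_factor_step_size). One plausible reading of that triple, sketched here in plain Python, is: a fixed scale when min equals max, a uniform draw when the step is 0, and otherwise a random pick from evenly spaced values. This is my interpretation rather than a copy of input_preprocess, and `pick_random_scale` is a hypothetical name.

```python
import random


def pick_random_scale(min_scale_factor, max_scale_factor, step_size, rng):
  """Picks one scale factor from the (min, max, step) triple."""
  if min_scale_factor < 0 or min_scale_factor > max_scale_factor:
    raise ValueError('Unexpected value of min_scale_factor.')
  if min_scale_factor == max_scale_factor:
    return float(min_scale_factor)  # No random scaling at all.
  if step_size == 0:
    # Continuous case: any value between min and max.
    return rng.uniform(min_scale_factor, max_scale_factor)
  # Discrete case: evenly spaced candidates from min to max.
  span = max_scale_factor - min_scale_factor
  num_steps = int(span / step_size + 1)
  if num_steps < 2:
    return float(min_scale_factor)
  candidates = [
      min_scale_factor + i * span / (num_steps - 1) for i in range(num_steps)
  ]
  return rng.choice(candidates)
```

With the defaults in Dataset.__init__ (min and max both 1.0, step 0) this returns 1.0, meaning no scaling. With 0.5, 2.0 and 0.25 the candidates are 0.5, 0.75, and so on up to 2.0.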

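This function resizes the image, randomly scales it, pads it with the mean pixel value, randomly crops it, and randomly flips it left to right. To see exactly what the preprocessing does you need to experiment with it. For now it is enough to know that the returned image and label are fully processed.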
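The key constraint in these steps is that the image and the label must receive the same random flip and the same random crop, or the pixels no longer line up. A minimal sketch on nested lists, assuming the inputs are already at least as large as the crop (the real pipeline pads first); `flip_and_crop_pair` is a hypothetical helper, not part of input_preprocess.

```python
import random


def flip_and_crop_pair(image, label, crop_height, crop_width, rng):
  """Applies one random left-right flip and one random crop to both inputs.

  image and label are 2-D lists (rows of pixels) of identical size.
  """
  # Draw the flip decision once and apply it to both grids.
  if rng.random() < 0.5:
    image = [row[::-1] for row in image]
    label = [row[::-1] for row in label]

  # Draw the crop offsets once and reuse them for both grids.
  height, width = len(image), len(image[0])
  top = rng.randint(0, height - crop_height)
  left = rng.randint(0, width - crop_width)

  def crop(grid):
    rows = grid[top:top + crop_height]
    return [row[left:left + crop_width] for row in rows]

  return crop(image), crop(label)
```

Because the random values are drawn once and shared, each label pixel still sits at the same position as the image pixel it describes.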

 

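  Note that the returned original_image is not necessarily at the original size; it may have been resized. It is used only for visualization. The LABELS_CLASS key holds the input to preprocessing and is only a stepping stone; the real label key is LABEL. So the final contents of sample are set like this: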

    sample[common.IMAGE] = image

    if not self.is_training:
      # Original image is only used during visualization.
      sample[common.ORIGINAL_IMAGE] = original_image

    if label is not None:
      sample[common.LABEL] = label
    # Remove common.LABEL_CLASS key in the sample since it is only used to
    # derive label and not used in training and evaluation.
    # (This is where LABELS_CLASS gets dropped.)
    sample.pop(common.LABELS_CLASS, None)

 

3. Method _get_all_files(self)

"""Gets all the files to read data from.

Returns:
  A list of input files.
"""
    file_pattern = _FILE_PATTERN
    file_pattern = os.path.join(self.dataset_dir,
                                file_pattern % self.split_name)
    return tf.gfile.Glob(file_pattern)
Glob is documented as returning:

A list of strings containing filenames that match the given pattern(s).

Typically the code is written like this:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

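Stepping through this in a debugger makes it clearer. The explanation here does not add much either: https://zhuanlan.zhihu.com/p/31536538

Judging from the docstring, this step collects the paths of all the data files, that is, the .tfrecord files in dataset_dir. Nothing has been read yet.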
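How the `'%s-*'` pattern selects files can be tried out with the standard library, using `glob.glob` as a stand-in for `tf.gfile.Glob`. The shard file names below are made up for the demonstration, and `get_all_files` is a free-function version of the method written only for this sketch.

```python
import glob
import os
import tempfile

_FILE_PATTERN = '%s-*'


def get_all_files(dataset_dir, split_name):
  """Builds '<dataset_dir>/<split_name>-*' and returns the matching paths."""
  file_pattern = os.path.join(dataset_dir, _FILE_PATTERN % split_name)
  return sorted(glob.glob(file_pattern))


with tempfile.TemporaryDirectory() as dataset_dir:
  # Hypothetical shard names, created empty just to have something to match.
  names = (
      'train-00000-of-00002.tfrecord',
      'train-00001-of-00002.tfrecord',
      'trainval-00000-of-00001.tfrecord',
      'val-00000-of-00001.tfrecord',
  )
  for name in names:
    open(os.path.join(dataset_dir, name), 'w').close()
  found = get_all_files(dataset_dir, 'train')
  train_files = [os.path.basename(path) for path in found]
```

Because the pattern includes the hyphen, the split name 'train' matches only the two train shards and not the trainval shard.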

4. Method get_one_shot_iterator(self)

"""Gets an iterator that iterates across the dataset once.

Returns:
  An iterator of type tf.data.Iterator.
"""

This method builds an iterator of type tf.data.Iterator.

 

    files = self._get_all_files()

    dataset = (
        tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
        .map(self._parse_function, num_parallel_calls=self.num_readers)
        .map(self._preprocess_image, num_parallel_calls=self.num_readers))

Official TFRecordDataset documentation: https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset

Class TFRecordDataset

Dataset comprising records from one or more TFRecord files.

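So what is a Dataset, and how is map() used with it? Note that Dataset here means tf.data.Dataset, not the Dataset class defined in this file. Put simply, TFRecordDataset creates a tf.data.Dataset instance. Our own class has not been instantiated at all at this point; that only happens at the very end. That is why I have not talked about an instance of our own Dataset class so far: it has not been used yet.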

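tf.data.Dataset is a class. For its definition and details, see the official guide:

https://www.tensorflow.org/guide/datasets?hl=zh-CN. It is a very commonly used data-processing class provided by TensorFlow.

Simply put, a Dataset is a collection of samples, and each sample contains an image and a label (for image segmentation).

TFRecordDataset is a Dataset whose elements are records read from TFRecord files. The map method applies the given function to every element (sample) in the Dataset.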

 

So what does this long chain of code actually do?

tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)

This calls the TFRecordDataset constructor to build a Dataset from the list of file names in files. Its elements are still raw serialized records.

 

Then

.map(self._parse_function, num_parallel_calls=self.num_readers)

applies _parse_function to turn each serialized record into a sample dictionary. At this point the image and label have been decoded into ordinary tensors.

Finally

.map(self._preprocess_image, num_parallel_calls=self.num_readers)

applies _preprocess_image to preprocess the image and label in the sample dictionary.

 

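After optional shuffling (with a buffer of 100), repeating, batching, and prefetching, the method returns a one-shot iterator.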
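The order of operations (map, map, shuffle, repeat, batch) can be imitated with plain Python generators. This is only an analogy for the tf.data chain: it shuffles a whole epoch instead of using a buffer of 100, and it has no parallel reads or prefetching. `iterate_batches` and its arguments are hypothetical names.

```python
import itertools
import random


def iterate_batches(records, parse_fn, preprocess_fn, batch_size,
                    should_shuffle=False, should_repeat=False, rng=None):
  """Yields lists of preprocessed samples, batch_size at a time."""
  rng = rng or random.Random(0)

  def one_epoch():
    # The two map() calls: parse, then preprocess, each record.
    samples = [preprocess_fn(parse_fn(record)) for record in records]
    if should_shuffle:
      rng.shuffle(samples)
    return samples

  def stream():
    # repeat(): loop over epochs forever, or stop after one pass.
    while True:
      for sample in one_epoch():
        yield sample
      if not should_repeat:
        return

  source = stream()
  while True:
    # batch(): group consecutive samples; the last batch may be smaller.
    batch = list(itertools.islice(source, batch_size))
    if not batch:
      return
    yield batch


records = ['0', '1', '2', '3', '4']
once = list(iterate_batches(records, int, lambda x: x * 2, batch_size=2))
repeated = list(itertools.islice(
    iterate_batches(records, int, lambda x: x * 2, batch_size=2,
                    should_repeat=True), 4))
```

Here `once` is `[[0, 2], [4, 6], [8]]`, and with repetition the third batch spans the epoch boundary, so `repeated` is `[[0, 2], [4, 6], [8, 0], [2, 4]]`. Batching after repeating is what allows a batch to straddle two epochs.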

Usage in train.py:

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)
def _train_deeplab_model(iterator, num_of_classes, ignore_label):

This instantiates a dataset, then calls get_one_shot_iterator() to obtain an iterator, and also passes along the number of classes (21 for VOC 2012) and the ignore label (255 for VOC 2012).

Summary

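The class name Dataset was not chosen at random. Like tf.data.Dataset, the class provides a sequence: a stream of data that can be fed to the network continuously without consuming large amounts of memory or CPU. That stream is the iterator, and it is what gets passed to the training function.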

 
