DeepLab-V3代码分析(二)

https://github.com/rishizek/tensorflow-deeplab-v3

文章目录

  • 三、DeepLab-V3 模型代码
    • 1. 空洞卷积
    • 2. 生成模型
    • 3. 导入、优化
  • 四、训练代码

三、DeepLab-V3 模型代码

deeplab_model.py

1. 空洞卷积

DeepLab-V3代码分析(二)_第1张图片

def atrous_spatial_pyramid_pooling(inputs, output_stride, batch_norm_decay, is_training, depth=256):
  """Build the Atrous Spatial Pyramid Pooling (ASPP) block.

  Five parallel branches are computed from `inputs` (one 1x1 convolution,
  three 3x3 atrous convolutions, and one globally pooled image-level branch),
  concatenated along the channel axis and fused by a final 1x1 convolution.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    output_stride: The ResNet unit's stride; must be 8 or 16. It selects the
      atrous rates: (6, 12, 18) for stride 16 and (12, 24, 36) for stride 8.
      (A smaller backbone stride yields a larger feature map, so the atrous
      rates are enlarged to compensate.)
    batch_norm_decay: The moving average decay when estimating layer
      activation statistics in batch normalization.
    is_training: A boolean denoting whether the input is for training.
    depth: Number of output channels requested from every convolution in
      this block, including the final fusing convolution.

  Returns:
    The atrous spatial pyramid pooling output.

  Raises:
    ValueError: If `output_stride` is neither 8 nor 16.
  """
  with tf.variable_scope("aspp"):
    if output_stride not in (8, 16):
      raise ValueError('output_stride must be either 8 or 16.')

    # Base rates apply to output_stride == 16; they are doubled for 8.
    rate_multiplier = 2 if output_stride == 8 else 1
    atrous_rates = [rate_multiplier * base_rate for base_rate in (6, 12, 18)]

    with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)), \
         arg_scope([layers.batch_norm], is_training=is_training):
      # Dynamic spatial size (height, width), needed later to upsample the
      # pooled image-level branch back to the feature-map resolution.
      spatial_size = tf.shape(inputs)[1:3]

      # (a) one 1x1 convolution followed by three 3x3 atrous convolutions,
      # one per rate. Branch order fixes the channel order of the concat.
      branches = [layers_lib.conv2d(inputs, depth, [1, 1], stride=1, scope="conv_1x1")]
      for branch_index, rate in enumerate(atrous_rates, 1):
        branches.append(
            resnet_utils.conv2d_same(
                inputs, depth, 3, stride=1, rate=rate,
                scope='conv_3x3_%d' % branch_index))

      # (b) the image-level features
      with tf.variable_scope("image_level_features"):
        # Global average pooling: tf.reduce_mean over the two spatial axes,
        # with keepdims=True so the result stays a rank-4 tensor.
        pooled = tf.reduce_mean(inputs, [1, 2], name='global_average_pooling', keepdims=True)
        # 1x1 convolution with `depth` filters (batch normalization is
        # presumably applied through the enclosing arg scope).
        pooled = layers_lib.conv2d(pooled, depth, [1, 1], stride=1, scope='conv_1x1')
        # Bilinearly upsample back to the spatial size of `inputs`.
        pooled = tf.image.resize_bilinear(pooled, spatial_size, name='upsample')
      branches.append(pooled)

      # Concatenate the 1x1 branch, the three 3x3 branches and the pooled
      # image-level branch along the channel axis, then fuse with a 1x1 conv.
      fused = tf.concat(branches, axis=3, name='concat')
      return layers_lib.conv2d(fused, depth, [1, 1], stride=1, scope='conv_1x1_concat')

2. 生成模型

def deeplab_v3_generator(num_classes,
                         output_stride,
                         base_architecture,
                         pre_trained_model,
                         batch_norm_decay,
                         data_format='channels_last'):

3. 导入、优化

def deeplabv3_model_fn(features, labels, mode, params):

四、训练代码

train.py

def get_filenames(is_training, data_dir):
	"""Return a list of filenames."""
def parse_record(raw_record):
	"""Parse PASCAL image and label from a tf record."""
def preprocess_image(image, label, is_training):
	"""Preprocess a single image of layout [height, width, depth]."""
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
	"""Input_fn using the tf.data input pipeline.

	NOTE(review): the original one-line summary ended with "for CIFAR-10
	dataset", which looks like a leftover from a copied example; the
	neighbouring parse_record stub parses PASCAL records, so this pipeline
	presumably feeds PASCAL data -- confirm against the full train.py.
	"""
def main(unused_argv):

你可能感兴趣的:(DL)