Semantic Segmentation: Notes on the Paper "ICNet for Real-Time Semantic Segmentation on High-Resolution Images"

ICNet

ICNet for Real-Time Semantic Segmentation on High-Resolution Images

Paper: ICNet

Code:

  • github-Caffe
  • TensorFlow

Abstract

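ICNet is a real-time semantic segmentation network built on PSPNet. Its goal is to cut PSPNet's inference time: the paper analyzes where PSPNet spends its time and adds a cascade feature fusion unit on top of it, yielding a model that is both fast and accurate. Results are reported on Cityscapes.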


Introduction

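Performance of existing semantic segmentation models on Cityscapes:

[Figure 1]
As the figure shows, many high-quality segmentation models are far too slow at inference to run in real time. The goal of this paper is to speed up inference enough to reach real-time operation without sacrificing too much segmentation quality. Starting from PSPNet, it looks for a balance point between accuracy and speed.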

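The main contributions of the paper are:

  • An image cascade framework that combines the processing speed of low-resolution images with the inference quality of high-resolution images, refining the segmentation prediction step by step
  • ICNet achieves more than 5x speedup and cuts memory consumption by more than 5x
  • ICNet runs at 30 fps at a resolution of 1024×2048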

Related Work

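Recently, CNNs have shown much stronger feature extraction ability than traditional methods.

  • High-quality semantic segmentation models: the pioneering FCN replaces FC layers with convolutional layers; DeepLab and others use dilated convolution; encoder-decoder structures fuse high-level semantics with low-level detail; some methods use CRF or MRF to model spatial relationships; PSPNet uses a spatial pyramid pooling structure. These methods improve accuracy but cannot be used in real-time systems.

  • Fast semantic segmentation models: SegNet gives up some layer information to gain speed; ENet is a lightweight network. These methods are fast, but their accuracy is poor.

  • Video segmentation models: video contains a lot of redundant information, which can be exploited to reduce computation, etc.

The paper presents a hierarchical structure for fast semantic segmentation that takes cascaded images as input to speed up inference and build a real-time segmentation system.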


Speed Analysis

ICNet is based on PSPNet, so we analyze PSPNet first. Its structure is shown below:

[Figure 2]

Time budget

First, look at how image resolution affects PSPNet's running time. The figure below shows the time cost at two resolutions:

[Figure 3]
Blue is the 1024×2048 input; green is the 512×1024 input.

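The figure conveys several things:

  • Running time differs greatly between resolutions: it grows with the number of pixels, i.e. quadratically with the side length
  • The wider the network, the slower it runs.
  • The more kernels a layer has, the slower it runs. For example, stage4 and stage5 work at the same resolution, yet their speed differs greatly because stage5 has twice as many kernels as stage4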
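
The trend in the list above follows from how the cost of a convolution is counted. Below is a minimal sketch, assuming a plain stride-1 convolution whose multiply-accumulate count is H·W·k²·C_in·C_out. The function name `conv_macs` and the channel counts are illustrative and not taken from the paper.

```python
def conv_macs(height, width, kernel, c_in, c_out):
    """Multiply-accumulate count of one stride-1 convolution layer."""
    return height * width * kernel * kernel * c_in * c_out

# Same layer at full and half resolution (illustrative channel counts).
full = conv_macs(1024, 2048, 3, 64, 64)
half = conv_macs(512, 1024, 3, 64, 64)

# Same resolution, but input and output channels are both doubled.
wide = conv_macs(512, 1024, 3, 128, 128)

print(full / half)  # 4.0
print(wide / half)  # 4.0
```

Halving each side of the input cuts the cost to a quarter, while doubling the kernels of consecutive layers multiplies it by four, which is why a late stage can be slow even at a low resolution.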

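Intuitive Speedup

Downsampling the input
By the analysis above, inference at half resolution takes about 1/4 of the time at full resolution. A simple way to test prediction quality at lower input resolutions is to feed the network inputs scaled to 1/2 and 1/4 and upsample the output back to the original size. The results:

[Figure 4]
The figure shows the drawbacks. At a scale of 0.25, inference time drops sharply, but the prediction is very coarse and loses many small yet important details. A scale of 0.5 is much better but still loses many details, and, worst of all, inference is no longer fast enough for real time.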
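
The loss of small details can be reproduced on a toy label map. This is a sketch under strong assumptions: the "network" is taken to be perfect on whatever low-resolution input it sees, and both resampling steps are nearest-neighbour. The helper names `downsample` and `upsample` are hypothetical.

```python
def downsample(grid, step):
    """Keep every `step`-th row and column (nearest-neighbour subsampling)."""
    return [row[::step] for row in grid[::step]]

def upsample(grid, factor):
    """Repeat every cell `factor` times in both directions."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(factor)]
        out.extend([list(wide) for _ in range(factor)])
    return out

# 8x8 label map: background 0 with a one-pixel-wide object (label 1) in column 5.
labels = [[1 if x == 5 else 0 for x in range(8)] for y in range(8)]

# Scale 0.25 and back: only columns 0 and 4 are sampled, so the object vanishes.
coarse = upsample(downsample(labels, 4), 4)
lost = all(v == 0 for row in coarse for v in row)
print(lost)  # True
```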

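Downsampling the features
If the input can be downsampled, so can the feature maps. PSPNet50 is tested with downsampling ratios of 1:8, 1:16 and 1:32:

[Figure 5]
Smaller feature maps trade accuracy for faster inference, but even at 1:32 (132 ms) the model still falls short of real time.

Model compression
A direct way to reduce network complexity is to trim the number of kernels in each layer, and the paper tests some effective model compression strategies. Even when only a quarter of the kernels are kept, inference time is still long, and accuracy drops considerably.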


Architecture

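To summarize the speed analysis above, the candidate optimizations are:

  • Input resolution: lowering the input resolution gives a large speedup but makes the prediction very blurry
  • Feature downsampling: downsampling the feature maps speeds things up but lowers accuracy
  • Model compression: compressing a trained model makes it lighter and faster, but the experimental results are poor

ICNet addresses these problems with a combined approach: use low resolution to capture semantics quickly, use high resolution to recover details, and join them with a cascade network to obtain good results within the time limit.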

The model structure is as follows:
[Figure 6]
The image is fed into the model along three paths at scales 1, 1/2 and 1/4 (the actual code differs from this description; see the code analysis later on). The three branches are:

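  • Low resolution
    Process: starts from the 1/16-size output of the medium-resolution branch and scales it down to 1/32. After convolution, several dilated convolutions enlarge the receptive field without shrinking the feature map, and the branch outputs a feature map at 1/32 of the original image size.
    Cost: many layers, but fast because the resolution is low; shares part of its weights with the second branch.
  • Medium resolution
    Process: takes the image at 1/2 resolution as input and downsamples it by 1/8 through convolution, giving a feature map at 1/16 of the original size. This is then fused with the output feature map of the low-resolution branch through a CFF (cascade feature fusion) unit to give the branch output. Note that the low- and medium-resolution branches share convolution parameters.
    Cost: 17 convolutional layers; shares part of its weights with the first branch; together with the first branch it takes 6 ms.
  • High resolution
    Process: takes the original image as input and downsamples it by 1/8 through convolution, giving a feature map at 1/8 of the original size, which is fused with the processed output of the medium-resolution branch through a CFF unit.
    Cost: 3 convolutional layers; the resolution is high, but with so few layers it takes 9 ms.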

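The output features of each branch are first upsampled by 2x to produce a prediction. During training, ground truth at 1/16, 1/8 and 1/4 resolution guides the respective branches. This auxiliary supervision makes gradient optimization smoother and helps training converge; as each branch becomes a stronger learner, the prediction is not dominated by any single branch. This gradual feature fusion with cascaded label guidance produces reasonable predictions.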
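
To make the label guidance concrete, the sketch below computes the label-map sizes for a 1024×2048 image and subsamples a toy label map. It assumes nearest-neighbour subsampling (class ids must not be blended by interpolation); the helper names `guidance_shapes` and `subsample_labels` are hypothetical and not from the paper or the repository.

```python
def guidance_shapes(height, width, factors=(16, 8, 4)):
    """Label-map sizes used to supervise the three branches."""
    return [(height // f, width // f) for f in factors]

def subsample_labels(labels, factor):
    """Nearest-neighbour label downsampling: keep every `factor`-th pixel."""
    return [row[::factor] for row in labels[::factor]]

print(guidance_shapes(1024, 2048))  # [(64, 128), (128, 256), (256, 512)]

# Toy 8x8 "label map" whose values encode the pixel position.
tiny = [[y * 8 + x for x in range(8)] for y in range(8)]
print(subsample_labels(tiny, 4))    # [[0, 4], [32, 36]]
```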

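ICNet uses the low-resolution input to do the semantic segmentation and the high-resolution input to refine the result. Structurally, far fewer features are produced while the necessary details are still kept.

Predictions from the different branches:

[Figure 7]
The output of the third branch is clearly the best. At test time, only the third branch's result is kept.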

CFF unit

In ICNet, branches are fused by the CFF unit. Its structure:

[Figure 8]
The low-resolution feature map is upsampled and then passed through a dilated convolution, which enlarges the receptive field of the upsampled result.
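
The fusion step can be sketched in plain Python on single-channel maps. This is a simplified illustration, not the paper's implementation: it assumes nearest-neighbour upsampling, omits batch normalization, and reduces the projection of the higher-resolution features to a scalar `proj_weight`. All function names are hypothetical.

```python
def upsample2x(fmap):
    """Nearest-neighbour 2x upsampling of a 2-D map."""
    out = []
    for row in fmap:
        wide = [v for v in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out

def dilated_conv3x3(fmap, kernel, dilation):
    """3x3 convolution with the given dilation; zero padding keeps the size."""
    h, w = len(fmap), len(fmap[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            acc = 0.0
            for ky in range(3):
                for kx in range(3):
                    yy = y + (ky - 1) * dilation
                    xx = x + (kx - 1) * dilation
                    if 0 <= yy < h and 0 <= xx < w:
                        acc += kernel[ky][kx] * fmap[yy][xx]
            out[y][x] = acc
    return out

def cff(low, high, kernel, proj_weight=1.0):
    """Upsample `low`, apply a dilated conv, add the projected `high`, then ReLU."""
    up = dilated_conv3x3(upsample2x(low), kernel, dilation=2)
    return [[max(0.0, u + proj_weight * v) for u, v in zip(row_up, row_high)]
            for row_up, row_high in zip(up, high)]

low = [[1.0, 2.0], [3.0, 4.0]]           # coarse, semantically rich map
high = [[0.5] * 4 for _ in range(4)]     # finer map at twice the resolution
identity = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]

fused = cff(low, high, identity)
print(fused[0])  # [1.5, 1.5, 2.5, 2.5]
```

With the identity kernel the dilated convolution passes the upsampled map through unchanged, so the output is simply the element-wise sum of the two maps after ReLU.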

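Note that the weight of the auxiliary label guidance is set to 0.4 (following PSPNet's experiments): if the weight $\lambda_3$ of the bottom-branch loss $L_3$ is set to 1, then the weight $\lambda_2$ of the middle-branch loss $L_2$ is 0.4 and the weight $\lambda_1$ of the top-branch loss $L_1$ is 0.16.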

Loss function and model compression

Loss function:
The loss is defined over the branches as: $L = \lambda_1 L_1 + \lambda_2 L_2 + \lambda_3 L_3$

Following the CFF setting, with the weight $\lambda_3$ of the bottom-branch loss $L_3$ set to 1, the weight $\lambda_2$ of the middle-branch loss $L_2$ is 0.4 and the weight $\lambda_1$ of the top-branch loss $L_1$ is 0.16.
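
As a quick illustration of the weighted sum, here is a minimal sketch using the weights quoted above. The per-branch loss values are made-up numbers, and `total_loss` is a hypothetical helper.

```python
# (lambda_1, lambda_2, lambda_3) for the top, middle and bottom branches.
LAMBDAS = (0.16, 0.4, 1.0)

def total_loss(branch_losses, lambdas=LAMBDAS):
    """L = lambda_1 * L_1 + lambda_2 * L_2 + lambda_3 * L_3."""
    return sum(weight * loss for weight, loss in zip(lambdas, branch_losses))

# Made-up per-branch cross-entropy values, for illustration only.
print(total_loss((2.0, 1.5, 1.0)))
```

The bottom (full-resolution) branch dominates the total, while the coarser branches act as auxiliary supervision.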

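Model compression
As mentioned earlier, a model can be compressed by reducing the number of kernels per layer, which lowers its complexity.

The paper uses a simple and effective approach: progressive compression. Taking a compression rate of 1/2 as an example, first compress to 3/4 and fine-tune the compressed model; then compress to 1/2 and fine-tune again. This keeps the compression stable.

The method of "Pruning filters for efficient convnets" is used here: for each filter, compute the L1 sum of its kernel weights, sort the filters in descending order, and remove those with the smallest values.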
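
The two ideas above, L1-based filter ranking and progressive compression, can be sketched as follows. This is an illustration under assumptions: each filter is a flat list of weights, the schedule (3/4 then 1/2) is expressed relative to the original filter count, and `finetune` is a placeholder callback standing in for real retraining. All names are hypothetical.

```python
def l1_norm(filt):
    """Sum of absolute weights of one filter."""
    return sum(abs(w) for w in filt)

def prune(filters, keep):
    """Keep the `keep` filters with the largest L1 norm, preserving their order."""
    ranked = sorted(range(len(filters)), key=lambda i: l1_norm(filters[i]), reverse=True)
    kept = sorted(ranked[:keep])
    return [filters[i] for i in kept]

def progressive_prune(filters, schedule, finetune):
    """Prune step by step, fine-tuning after every step."""
    original = len(filters)
    for ratio in schedule:
        keep = max(1, int(round(original * ratio)))
        filters = finetune(prune(filters, keep))
    return filters

# Eight toy filters whose L1 norms are 2, 4, ..., 16.
filters = [[i, -i] for i in range(1, 9)]

# Placeholder fine-tuning step: returns the filters unchanged.
result = progressive_prune(filters, schedule=(0.75, 0.5), finetune=lambda f: f)
print(result)  # [[5, -5], [6, -6], [7, -7], [8, -8]]
```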


Experiment

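Experimental details:

  • Platform: Caffe, CUDA 7.5, cuDNN v5, single TitanX
  • Inference time measurement: Caffe time, averaged over 100 runs
  • Batch size: 16
  • Learning rate: poly policy, base learning rate 0.01, momentum 0.9
  • Iterations: 30K
  • Weight decay: 0.0001
  • Data augmentation: random flipping, random scaling between 0.5 and 2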
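
For reference, the poly learning-rate policy decays the rate as base_lr · (1 − iter/max_iter)^power. The sketch below uses the base rate and iteration count from the list above; the exponent `power=0.9` is an assumption (the notes give the momentum but not the exponent), and `poly_lr` is a hypothetical helper.

```python
def poly_lr(iteration, base_lr=0.01, max_iter=30000, power=0.9):
    """Poly policy: base_lr * (1 - iteration / max_iter) ** power.

    `power` is an assumed value; it is not stated in these notes.
    """
    return base_lr * (1.0 - iteration / max_iter) ** power

for it in (0, 15000, 30000):
    print(it, poly_lr(it))
```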

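ICNet is modified from PSPNet: the concatenation in PSPNet's pyramid pooling is replaced by summation, the channel count is reduced from 4096 to 2048, and the kernel of the convolution after the PSP module is changed from 3×3 to 1×1. This has little effect on the results but saves a great deal of computation.

The dataset is Cityscapes; the evaluation metrics are mIoU and inference time.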
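
As a reminder of what mIoU measures, here is a minimal sketch on flattened label lists. It is simplified relative to the official Cityscapes evaluation: there is no ignore label, and classes absent from both prediction and ground truth are skipped (an assumption of this sketch). `mean_iou` is a hypothetical helper.

```python
def mean_iou(pred, truth, num_classes):
    """Mean intersection-over-union over the classes that occur."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(pred, truth) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, truth) if p == c or t == c)
        if union:  # skip classes that appear in neither map
            ious.append(inter / union)
    return sum(ious) / len(ious)

pred = [0, 0, 1, 1, 1, 2]
truth = [0, 1, 1, 1, 2, 2]
print(mean_iou(pred, truth, num_classes=3))  # 0.5
```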

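Model compression experiment

Taking PSPNet50 as an example, the result of compressing it directly is shown as Baseline in the table below.

mIoU drops, yet the inference time of 170 ms is still not real time. This shows that model compression alone cannot deliver real-time performance with good segmentation quality. ICNet reaches similar segmentation quality while being more than 5 times faster.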

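Effectiveness of the cascade structure

The cascade structure is evaluated by comparing the outputs of different branches, as in the table below:

[Figure 9]

sub4 is the result with only the low-resolution input, sub24 uses the first two branches, and sub124 uses all branches. With all branches the model is still fast, its accuracy is close to PSPNet's, and it holds 30 fps. Memory consumption is also clearly reduced.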

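Visual comparison

The figure below compares inputs and outputs:

sub4 already captures most of the semantics, but because its input is low-resolution, its output is coarse. Both sub4 and sub24 lose details to some degree.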

Quantitative analysis

To further measure the efficiency gain of each branch, a quantitative analysis based on connected components is carried out (I did not fully follow this experiment).

[Figure 10]

Results on Cityscapes

Comparison of ICNet with other models after 90K training iterations:

[Figure 11]

ICNet performs well.


Conclusion

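The paper improves on PSPNet to obtain ICNet. The core idea is to get semantic information quickly from the low-resolution input and detail from the high-resolution input, and to fuse the two into a model that balances speed and accuracy.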


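Code analysis

The original Caffe version is not analyzed directly here; a TensorFlow version that is easier to follow is used instead.

Code framework

The TensorFlow code is interesting: it uses a decorator to build a chainable model. We first look at the decorator and the implementation of the Network base class.

The network module contains the decorator and related definitions: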

import numpy as np
import tensorflow as tf
slim = tf.contrib.slim

DEFAULT_PADDING = 'VALID'
DEFAULT_DATAFORMAT = 'NHWC'
layer_name = []
BN_param_map = {'scale':    'gamma',
                'offset':   'beta',
                'variance': 'moving_variance',
                'mean':     'moving_mean'}
                
def layer(op):
    '''Decorator for composable network layers.'''

    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer inputs.
        if len(self.terminals) == 0:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        # Perform the operation and get the output.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Add to layer LUT.
        self.layers[name] = layer_output
        layer_name.append(name)
        # This output is now the input for the next layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated


class Network(object):

    def __init__(self, inputs, trainable=True, is_training=False, num_classes=21):
        '''Base Network class providing the required methods and layer definitions.'''
        self.inputs = inputs  # The input nodes for this network
        self.terminals = []  # The current list of terminal nodes
        self.layers = dict(inputs)  # Mapping from layer names to layers
        # If true, the resulting variables are set as trainable
        self.is_training = is_training
        self.trainable = trainable
        # Switch variable for dropout
        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
                                                       shape=[],
                                                       name='use_dropout')

        self.setup(is_training, num_classes)

    def setup(self, is_training, num_classes):
        '''Construct the network. '''
        raise NotImplementedError('Must be implemented by the subclass.')

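    def load(self, data_path, session, ignore_missing=False):
        '''Load network weights.
        data_path: path to the numpy-serialized network weights
        session: the current TensorFlow session
        ignore_missing: if true, serialized weights for missing layers are ignored
        '''
        # allow_pickle is needed on newer NumPy to load a pickled dict
        data_dict = np.load(data_path, encoding='latin1', allow_pickle=True).item()
        for op_name in data_dict:
            with tf.variable_scope(op_name, reuse=True):
                # dict.items() works on Python 2 and 3 (iteritems is Python 2 only)
                for param_name, data in data_dict[op_name].items():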
                    try:
                        if 'bn' in op_name:
                            param_name = BN_param_map[param_name]

                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                    except ValueError:
                        if not ignore_missing:
                            raise

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            if isinstance(fed_layer, str):
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

    def get_output(self):
        '''Returns the current network output.'''
        return self.terminals[-1]

    def get_unique_name(self, prefix):
        '''Returns an index-suffixed unique name for the given prefix.
        This is used for auto-generating layer names based on the type-prefix.
        '''
        ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
        return '%s_%d' % (prefix, ident)

    def make_var(self, name, shape):
        '''Creates a new TensorFlow variable.'''
        return tf.get_variable(name, shape, trainable=self.trainable)

    def get_layer_name(self):
        return layer_name

    def validate_padding(self, padding):
        '''Verifies that the padding is one of the supported ones.'''
        assert padding in ('SAME', 'VALID')

    @layer
    def zero_padding(self, input, paddings, name):
        '''Zero-padding layer.'''
        pad_mat = np.array([[0,0], [paddings, paddings], [paddings, paddings], [0, 0]])
        return tf.pad(input, paddings=pad_mat, name=name)

    @layer
    def conv(self,
             input,
             k_h,
             k_w,
             c_o,
             s_h,
             s_w,
             name,
             relu=True,
             padding=DEFAULT_PADDING,
             group=1,
             biased=True):
        '''Convolution layer.'''
        # Verify that the padding is acceptable
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = input.get_shape()[-1]

        convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding,data_format=DEFAULT_DATAFORMAT)
        with tf.variable_scope(name) as scope:
            kernel = self.make_var('weights', shape=[k_h, k_w, c_i, c_o])
            output = convolve(input, kernel)

            if biased:
                biases = self.make_var('biases', [c_o])
                output = tf.nn.bias_add(output, biases)
            if relu:
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def atrous_conv(self,
                    input,
                    k_h,
                    k_w,
                    c_o,
                    dilation,
                    name,
                    relu=True,
                    padding=DEFAULT_PADDING,
                    group=1,
                    biased=True):
        '''Atrous (dilated) convolution layer.'''
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = input.get_shape()[-1]

        convolve = lambda i, k: tf.nn.atrous_conv2d(i, k, dilation, padding=padding)
        with tf.variable_scope(name) as scope:
            kernel = self.make_var('weights', shape=[k_h, k_w, c_i, c_o])
            output = convolve(input, kernel)

            if biased:
                biases = self.make_var('biases', [c_o])
                output = tf.nn.bias_add(output, biases)
            if relu:
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def relu(self, input, name):
        return tf.nn.relu(input, name=name)

    @layer
    def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.max_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name,
                              data_format=DEFAULT_DATAFORMAT)

    @layer
    def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)

        output = tf.nn.avg_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name,
                              data_format=DEFAULT_DATAFORMAT)
        return output

    @layer
    def lrn(self, input, radius, alpha, beta, name, bias=1.0):
        return tf.nn.local_response_normalization(input,
                                                  depth_radius=radius,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  bias=bias,
                                                  name=name)

    @layer
    def concat(self, inputs, axis, name):
        return tf.concat(axis=axis, values=inputs, name=name)

    @layer
    def add(self, inputs, name):
        return tf.add_n(inputs, name=name)

    @layer
    def fc(self, input, num_out, name, relu=True):
        with tf.variable_scope(name) as scope:
            input_shape = input.get_shape()
            if input_shape.ndims == 4:
                # The input is spatial. Vectorize it first.
                dim = 1
                for d in input_shape[1:].as_list():
                    dim *= d
                feed_in = tf.reshape(input, [-1, dim])
            else:
                feed_in, dim = (input, input_shape[-1].value)
            weights = self.make_var('weights', shape=[dim, num_out])
            biases = self.make_var('biases', [num_out])
            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
            fc = op(feed_in, weights, biases, name=scope.name)
            return fc

    @layer
    def softmax(self, input, name):
        # Build a real list: map() returns an iterator on Python 3 and has no len()
        input_shape = [v.value for v in input.get_shape()]
        if len(input_shape) > 2:
            # For certain models (like NiN), the singleton spatial dimensions
            # need to be explicitly squeezed, since they're not broadcast-able
            # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
            if input_shape[1] == 1 and input_shape[2] == 1:
                input = tf.squeeze(input, squeeze_dims=[1, 2])
            else:
                raise ValueError('Rank 2 tensor input expected for softmax!')
        # Pass name by keyword; the second positional argument is the axis
        return tf.nn.softmax(input, name=name)

    @layer
    def batch_normalization(self, input, name, scale_offset=True, relu=False):
        """
        # NOTE: Currently, only inference is supported
        with tf.variable_scope(name) as scope:
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape)
                offset = self.make_var('offset', shape=shape)
            else:
                scale, offset = (None, None)
            output = tf.nn.batch_normalization(
                input,
                mean=self.make_var('mean', shape=shape),
                variance=self.make_var('variance', shape=shape),
                offset=offset,
                scale=scale,
                # TODO: This is the default Caffe batch norm eps
                # Get the actual eps from parameters
                variance_epsilon=1e-5,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output
        """
        output = tf.layers.batch_normalization(
                    input,
                    momentum=0.95,
                    epsilon=1e-5,
                    training=self.is_training,
                    name=name
                )

        if relu:
            output = tf.nn.relu(output)

        return output

    @layer
    def dropout(self, input, keep_prob, name):
        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
        return tf.nn.dropout(input, keep, name=name)

    @layer
    def resize_bilinear(self, input, size, name):
        '''Resize to the given size with bilinear interpolation.'''
        return tf.image.resize_bilinear(input, size=size, align_corners=True, name=name)

    @layer
    def interp(self, input, factor, name):
        '''Resize the input by the given scale factor (bilinear interpolation).'''
        ori_h, ori_w = input.get_shape().as_list()[1:3]
        resize_shape = [int(ori_h * factor), int(ori_w * factor)]

        return tf.image.resize_bilinear(input, size=resize_shape, align_corners=True, name=name)
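
The chaining mechanism is easier to see without TensorFlow. Below is a stripped-down sketch of the same pattern, in which the "layers" operate on plain numbers. `TinyNet`, `scale` and the layer names are hypothetical and exist only for this illustration.

```python
def layer(op):
    """Wrap a layer method so it consumes the current terminals and returns self."""
    def layer_decorated(self, *args, **kwargs):
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        if len(self.terminals) == 0:
            raise RuntimeError('No input found for layer %s.' % name)
        layer_input = self.terminals[0] if len(self.terminals) == 1 else list(self.terminals)
        self.layers[name] = op(self, layer_input, *args, **kwargs)
        self.feed(self.layers[name])   # the output becomes the next input
        return self                    # returning self enables chaining
    return layer_decorated

class TinyNet(object):
    def __init__(self, inputs):
        self.layers = dict(inputs)     # layer name -> value
        self.terminals = []            # current inputs for the next layer

    def feed(self, *args):
        # Strings are looked up by layer name; anything else is used as-is.
        self.terminals = [self.layers[a] if isinstance(a, str) else a for a in args]
        return self

    def get_unique_name(self, prefix):
        ident = sum(n.startswith(prefix) for n in self.layers) + 1
        return '%s_%d' % (prefix, ident)

    @layer
    def scale(self, x, factor, name):
        return x * factor

    @layer
    def add(self, inputs, name):
        return sum(inputs)

net = TinyNet({'data': 3})
net.feed('data').scale(2, name='double').scale(5, name='times10')
net.feed('data', 'times10').add(name='sum')   # a skip connection by name
print(net.layers['sum'])  # 33
```

Every decorated call stores its output under its name, which is what lets later chains start from any earlier layer with `feed('layer_name')`, exactly as the residual connections in the model definition do.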


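Building ICNet

With the groundwork above, look at the definition of ICNet_BN in model.py. This implementation does not perform model compression; instead it is built with half the kernels from the start.

Note that the ICNet_BN class inherits from Network and overrides the setup method.

Middle-branch structure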

class ICNet_BN(Network):
    def setup(self, is_training, num_classes):
        (self.feed('data')  # start from the input layer
             .interp(factor=0.5, name='data_sub2')
             .conv(3, 3, 32, 2, 2, biased=False, padding='SAME', relu=False, name='conv1_1_3x3_s2')
             .batch_normalization(relu=True, name='conv1_1_3x3_s2_bn')
             .conv(3, 3, 32, 1, 1, biased=False, padding='SAME', relu=False, name='conv1_2_3x3')
             .batch_normalization(relu=True, name='conv1_2_3x3_bn')
             .conv(3, 3, 64, 1, 1, biased=False, padding='SAME', relu=False, name='conv1_3_3x3')
             .batch_normalization(relu=True, name='conv1_3_3x3_bn')
             .max_pool(3, 3, 2, 2, name='pool1_3x3_s2')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv2_1_1x1_proj')
             .batch_normalization(relu=False, name='conv2_1_1x1_proj_bn'))

        (self.feed('pool1_3x3_s2')
             .conv(1, 1, 32, 1, 1, biased=False, relu=False, name='conv2_1_1x1_reduce')
             .batch_normalization(relu=True, name='conv2_1_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding1')
             .conv(3, 3, 32, 1, 1, biased=False, relu=False, name='conv2_1_3x3')
             .batch_normalization(relu=True, name='conv2_1_3x3_bn')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv2_1_1x1_increase')
             .batch_normalization(relu=False, name='conv2_1_1x1_increase_bn'))

        (self.feed('conv2_1_1x1_proj_bn',
                   'conv2_1_1x1_increase_bn')
             .add(name='conv2_1')
             .relu(name='conv2_1/relu')
             .conv(1, 1, 32, 1, 1, biased=False, relu=False, name='conv2_2_1x1_reduce')
             .batch_normalization(relu=True, name='conv2_2_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding2')
             .conv(3, 3, 32, 1, 1, biased=False, relu=False, name='conv2_2_3x3')
             .batch_normalization(relu=True, name='conv2_2_3x3_bn')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv2_2_1x1_increase')
             .batch_normalization(relu=False, name='conv2_2_1x1_increase_bn'))

        (self.feed('conv2_1/relu',
                   'conv2_2_1x1_increase_bn')
             .add(name='conv2_2')
             .relu(name='conv2_2/relu')
             .conv(1, 1, 32, 1, 1, biased=False, relu=False, name='conv2_3_1x1_reduce')
             .batch_normalization(relu=True, name='conv2_3_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding3')
             .conv(3, 3, 32, 1, 1, biased=False, relu=False, name='conv2_3_3x3')
             .batch_normalization(relu=True, name='conv2_3_3x3_bn')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv2_3_1x1_increase')
             .batch_normalization(relu=False, name='conv2_3_1x1_increase_bn'))

        (self.feed('conv2_2/relu',
                   'conv2_3_1x1_increase_bn')
             .add(name='conv2_3')
             .relu(name='conv2_3/relu')
             .conv(1, 1, 256, 2, 2, biased=False, relu=False, name='conv3_1_1x1_proj')
             .batch_normalization(relu=False, name='conv3_1_1x1_proj_bn'))

        (self.feed('conv2_3/relu')
             .conv(1, 1, 64, 2, 2, biased=False, relu=False, name='conv3_1_1x1_reduce')
             .batch_normalization(relu=True, name='conv3_1_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding4')
             .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='conv3_1_3x3')
             .batch_normalization(relu=True, name='conv3_1_3x3_bn')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv3_1_1x1_increase')
             .batch_normalization(relu=False, name='conv3_1_1x1_increase_bn'))

        (self.feed('conv3_1_1x1_proj_bn',
                   'conv3_1_1x1_increase_bn')
             .add(name='conv3_1')
             .relu(name='conv3_1/relu')
             .interp(factor=0.5, name='conv3_1_sub4')
             .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='conv3_2_1x1_reduce')
             .batch_normalization(relu=True, name='conv3_2_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding5')
             .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='conv3_2_3x3')
             .batch_normalization(relu=True, name='conv3_2_3x3_bn')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv3_2_1x1_increase')
             .batch_normalization(relu=False, name='conv3_2_1x1_increase_bn'))

The base layers use two common units:

[Figure 12]

The first two convolutions on the main path both reduce the channel count, which greatly cuts computation while preserving segmentation quality.
Ordinary residual module: later on, the middle convolution of the main path is replaced by a dilated convolution.
Special residual module: can increase channels, downsample (used together with the channel increase), carry a dilated convolution, and so on.

A diagram of the code above:

[Figure 13]

To summarize, assume the original input is (1024, 1024, 3):

  • Resize the image to 1/2 of its height and width, giving (512, 512, 3)
  • Three Conv+BN layers (a downsampling convolution, then two more convolutions) extract features at the front, giving (256, 256, 64)
  • Max pooling gives (127, 127, 64)
  • A channel-increase module gives (127, 127, 128)
  • Two ordinary residual modules keep it at (127, 127, 128)
  • A downsampling channel-increase module gives conv3_1/relu with shape (64, 64, 256)
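
The spatial sizes in the list above can be checked with TensorFlow's standard output-size rules: ceil(size / stride) for 'SAME' padding and floor((size − kernel) / stride) + 1 for 'VALID'. A minimal sketch; `out_size` is a hypothetical helper, and the layer names in the comments refer to the model code above.

```python
import math

def out_size(size, kernel, stride, padding):
    """Output length of a conv/pool along one spatial dimension."""
    if padding == 'SAME':
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1   # 'VALID'

s = 1024
s = int(s * 0.5)                           # interp(factor=0.5)          -> 512
s = out_size(s, 3, 2, 'SAME')              # conv1_1_3x3_s2              -> 256
s = out_size(s, 3, 1, 'SAME')              # conv1_2_3x3 / conv1_3_3x3   -> 256
pooled = out_size(s, 3, 2, 'VALID')        # pool1_3x3_s2                -> 127
conv3 = out_size(pooled, 1, 2, 'VALID')    # conv3_1_1x1_proj, stride 2  -> 64

print(s, pooled, conv3)  # 256 127 64
```

The odd size 127 comes from the max pooling using the default 'VALID' padding; the stride-2 1×1 convolution then brings it back to 64.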

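The corresponding part of the ICNet diagram:

[Figure 14]
At this point the CONVS part of the middle branch in the ICNet diagram is implemented. Because the top branch and the middle branch share these early convolutions, what remains is the rest of the top branch.

Complete top-branch structure

Note that the input of the top branch is obtained by downsampling the middle branch's features, not by downsampling the original image to 1/4 and feeding that in. See icnet_cityscapes_bnnomerge.prototxt in the code released with the paper.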

        # Excerpt; partly overlaps the previous code listing
        (self.feed('conv3_1_1x1_proj_bn',
                   'conv3_1_1x1_increase_bn')
             .add(name='conv3_1')
             .relu(name='conv3_1/relu')
             .interp(factor=0.5, name='conv3_1_sub4')  # downsample the features directly
             .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='conv3_2_1x1_reduce')
             .batch_normalization(relu=True, name='conv3_2_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding5')
             .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='conv3_2_3x3')
             .batch_normalization(relu=True, name='conv3_2_3x3_bn')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv3_2_1x1_increase')
             .batch_normalization(relu=False, name='conv3_2_1x1_increase_bn'))

        (self.feed('conv3_1_sub4',
                   'conv3_2_1x1_increase_bn')
             .add(name='conv3_2')
             .relu(name='conv3_2/relu')
             .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='conv3_3_1x1_reduce')
             .batch_normalization(relu=True, name='conv3_3_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding6')
             .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='conv3_3_3x3')
             .batch_normalization(relu=True, name='conv3_3_3x3_bn')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv3_3_1x1_increase')
             .batch_normalization(relu=False, name='conv3_3_1x1_increase_bn'))


        (self.feed('conv3_2/relu',
                   'conv3_3_1x1_increase_bn')
             .add(name='conv3_3')
             .relu(name='conv3_3/relu')
             .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='conv3_4_1x1_reduce')
             .batch_normalization(relu=True, name='conv3_4_1x1_reduce_bn')
             .zero_padding(paddings=1, name='padding7')
             .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='conv3_4_3x3')
             .batch_normalization(relu=True, name='conv3_4_3x3_bn')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv3_4_1x1_increase')
             .batch_normalization(relu=False, name='conv3_4_1x1_increase_bn'))

        (self.feed('conv3_3/relu',
                   'conv3_4_1x1_increase_bn')
             .add(name='conv3_4')
             .relu(name='conv3_4/relu')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_1_1x1_proj')
             .batch_normalization(relu=False, name='conv4_1_1x1_proj_bn'))

        (self.feed('conv3_4/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv4_1_1x1_reduce')
             .batch_normalization(relu=True, name='conv4_1_1x1_reduce_bn')
             .zero_padding(paddings=2, name='padding8')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv4_1_3x3')
             .batch_normalization(relu=True, name='conv4_1_3x3_bn')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_1_1x1_increase')
             .batch_normalization(relu=False, name='conv4_1_1x1_increase_bn'))

        (self.feed('conv4_1_1x1_proj_bn',
                   'conv4_1_1x1_increase_bn')
             .add(name='conv4_1')
             .relu(name='conv4_1/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv4_2_1x1_reduce')
             .batch_normalization(relu=True, name='conv4_2_1x1_reduce_bn')
             .zero_padding(paddings=2, name='padding9')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv4_2_3x3')
             .batch_normalization(relu=True, name='conv4_2_3x3_bn')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_2_1x1_increase')
             .batch_normalization(relu=False, name='conv4_2_1x1_increase_bn'))

        (self.feed('conv4_1/relu',
                   'conv4_2_1x1_increase_bn')
             .add(name='conv4_2')
             .relu(name='conv4_2/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv4_3_1x1_reduce')
             .batch_normalization(relu=True, name='conv4_3_1x1_reduce_bn')
             .zero_padding(paddings=2, name='padding10')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv4_3_3x3')
             .batch_normalization(relu=True, name='conv4_3_3x3_bn')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_3_1x1_increase')
             .batch_normalization(relu=False, name='conv4_3_1x1_increase_bn'))

        (self.feed('conv4_2/relu',
                   'conv4_3_1x1_increase_bn')
             .add(name='conv4_3')
             .relu(name='conv4_3/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv4_4_1x1_reduce')
             .batch_normalization(relu=True, name='conv4_4_1x1_reduce_bn')
             .zero_padding(paddings=2, name='padding11')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv4_4_3x3')
             .batch_normalization(relu=True, name='conv4_4_3x3_bn')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_4_1x1_increase')
             .batch_normalization(relu=False, name='conv4_4_1x1_increase_bn'))

        (self.feed('conv4_3/relu',
                   'conv4_4_1x1_increase_bn')
             .add(name='conv4_4')
             .relu(name='conv4_4/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv4_5_1x1_reduce')
             .batch_normalization(relu=True, name='conv4_5_1x1_reduce_bn')
             .zero_padding(paddings=2, name='padding12')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv4_5_3x3')
             .batch_normalization(relu=True, name='conv4_5_3x3_bn')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_5_1x1_increase')
             .batch_normalization(relu=False, name='conv4_5_1x1_increase_bn'))

        (self.feed('conv4_4/relu',
                   'conv4_5_1x1_increase_bn')
             .add(name='conv4_5')
             .relu(name='conv4_5/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv4_6_1x1_reduce')
             .batch_normalization(relu=True, name='conv4_6_1x1_reduce_bn')
             .zero_padding(paddings=2, name='padding13')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv4_6_3x3')
             .batch_normalization(relu=True, name='conv4_6_3x3_bn')
             .conv(1, 1, 512, 1, 1, biased=False, relu=False, name='conv4_6_1x1_increase')
             .batch_normalization(relu=False, name='conv4_6_1x1_increase_bn'))

        (self.feed('conv4_5/relu',
                   'conv4_6_1x1_increase_bn')
             .add(name='conv4_6')
             .relu(name='conv4_6/relu')
             .conv(1, 1, 1024, 1, 1, biased=False, relu=False, name='conv5_1_1x1_proj')
             .batch_normalization(relu=False, name='conv5_1_1x1_proj_bn'))

        (self.feed('conv4_6/relu')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv5_1_1x1_reduce')
             .batch_normalization(relu=True, name='conv5_1_1x1_reduce_bn')
             .zero_padding(paddings=4, name='padding14')
             .atrous_conv(3, 3, 256, 4, biased=False, relu=False, name='conv5_1_3x3')
             .batch_normalization(relu=True, name='conv5_1_3x3_bn')
             .conv(1, 1, 1024, 1, 1, biased=False, relu=False, name='conv5_1_1x1_increase')
             .batch_normalization(relu=False, name='conv5_1_1x1_increase_bn'))

        (self.feed('conv5_1_1x1_proj_bn',
                   'conv5_1_1x1_increase_bn')
             .add(name='conv5_1')
             .relu(name='conv5_1/relu')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv5_2_1x1_reduce')
             .batch_normalization(relu=True, name='conv5_2_1x1_reduce_bn')
             .zero_padding(paddings=4, name='padding15')
             .atrous_conv(3, 3, 256, 4, biased=False, relu=False, name='conv5_2_3x3')
             .batch_normalization(relu=True, name='conv5_2_3x3_bn')
             .conv(1, 1, 1024, 1, 1, biased=False, relu=False, name='conv5_2_1x1_increase')
             .batch_normalization(relu=False, name='conv5_2_1x1_increase_bn'))

        (self.feed('conv5_1/relu',
                   'conv5_2_1x1_increase_bn')
             .add(name='conv5_2')
             .relu(name='conv5_2/relu')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv5_3_1x1_reduce')
             .batch_normalization(relu=True, name='conv5_3_1x1_reduce_bn')
             .zero_padding(paddings=4, name='padding16')
             .atrous_conv(3, 3, 256, 4, biased=False, relu=False, name='conv5_3_3x3')
             .batch_normalization(relu=True, name='conv5_3_3x3_bn')
             .conv(1, 1, 1024, 1, 1, biased=False, relu=False, name='conv5_3_1x1_increase')
             .batch_normalization(relu=False, name='conv5_3_1x1_increase_bn'))

        (self.feed('conv5_2/relu',
                   'conv5_3_1x1_increase_bn')
             .add(name='conv5_3')
             .relu(name='conv5_3/relu'))

        shape = self.layers['conv5_3/relu'].get_shape().as_list()[1:3]
        h, w = shape

        # Integer division keeps the pooling sizes as ints on Python 3
        (self.feed('conv5_3/relu')
             .avg_pool(h, w, h, w, name='conv5_3_pool1')
             .resize_bilinear(shape, name='conv5_3_pool1_interp'))

        (self.feed('conv5_3/relu')
             .avg_pool(h // 2, w // 2, h // 2, w // 2, name='conv5_3_pool2')
             .resize_bilinear(shape, name='conv5_3_pool2_interp'))

        (self.feed('conv5_3/relu')
             .avg_pool(h // 3, w // 3, h // 3, w // 3, name='conv5_3_pool3')
             .resize_bilinear(shape, name='conv5_3_pool3_interp'))

        (self.feed('conv5_3/relu')
             .avg_pool(h // 4, w // 4, h // 4, w // 4, name='conv5_3_pool6')
             .resize_bilinear(shape, name='conv5_3_pool6_interp'))

        (self.feed('conv5_3/relu',
                   'conv5_3_pool6_interp',
                   'conv5_3_pool3_interp',
                   'conv5_3_pool2_interp',
                   'conv5_3_pool1_interp')
             .add(name='conv5_3_sum')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv5_4_k1')
             .batch_normalization(relu=True, name='conv5_4_k1_bn')
             .interp(factor=2.0, name='conv5_4_interp')
             .zero_padding(paddings=2, name='padding17')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv_sub4')
             .batch_normalization(relu=False, name='conv_sub4_bn'))

The code corresponds to the following diagram:

[Image 15 of the original post]

To summarize: the top branch starts from the middle-branch output conv3_1/relu, of size $(64,64,256)$.

  • First resize the feature map to half its height and width, giving $(32,32,256)$
  • Apply three ordinary residual blocks; the size stays $(32,32,256)$
  • Apply a channel-increasing dilated convolution block; the output is $(32,32,512)$
  • Apply five ordinary dilated residual blocks (ordinary residual blocks whose main-path convolution is replaced by a dilated convolution); the output is $(32,32,512)$
  • Apply another channel-increasing dilated convolution block; the output is $(32,32,1024)$
  • Apply two ordinary dilated residual blocks to obtain conv5_3/relu, of size $(32,32,1024)$
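The reason the dilated residual blocks leave the spatial size unchanged can be checked with a little arithmetic. The sketch below is a minimal illustration, assuming that `atrous_conv` applies no implicit padding after the explicit `zero_padding` call (an assumption about the wrapper, not something shown in the code above); `conv_output_size` is a hypothetical helper introduced here.

```python
def conv_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution applied after explicit zero padding."""
    # A dilated kernel covers dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# zero_padding(paddings=4) followed by atrous_conv(3, 3, ..., 4) on a 32x32 map
size_conv5 = conv_output_size(32, kernel=3, padding=4, dilation=4)

# zero_padding(paddings=2) followed by atrous_conv(3, 3, ..., 2) on a 64x64 map
size_sub4 = conv_output_size(64, kernel=3, padding=2, dilation=2)

print(size_conv5, size_sub4)
```

For a 3×3 kernel, padding equal to the dilation rate keeps the input and output sizes identical, which is why the padding/dilation pairs in the code are always matched (4 with 4, 2 with 2).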

This completes the dilated-convolution part of the top branch. Next comes the channel reduction, shown in the lower part of the diagram above:

  • Feed conv5_3/relu into the PSP module

    • It goes through 4 average-pooling layers, nominally producing $(1,1,1024),(2,2,1024),(3,3,1024),(6,6,1024)$ (with the h/4 kernel and stride used in the code above, the last layer would yield a 4×4 grid on a 32×32 map rather than 6×6)
    • Bilinear interpolation resizes the 4 pooled outputs back to $(32,32,1024)$
    • The 4 resized outputs $(32,32,1024)$ are added element-wise to conv5_3/relu, giving $(32,32,1024)$
  • A 1×1 convolution then reduces the channels, giving conv5_4_k1_bn, of size $(32,32,256)$
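The pooled grid sizes in the PSP module follow directly from the kernel and stride passed to `avg_pool`. The sketch below is a rough check, assuming unpadded ('VALID'-style) pooling and integer division of the feature size; `valid_pool_size` and `psp_grid_sizes` are hypothetical helpers, not part of the repository.

```python
def valid_pool_size(size, kernel, stride):
    """Output size of an unpadded pooling layer along one spatial axis."""
    return (size - kernel) // stride + 1


def psp_grid_sizes(feature_size, divisors=(1, 2, 3, 4)):
    """Grid size per pooling branch when kernel == stride == feature_size // d."""
    sizes = {}
    for d in divisors:
        kernel = feature_size // d
        sizes[d] = valid_pool_size(feature_size, kernel, kernel)
    return sizes


# conv5_3/relu is 32x32 in the walkthrough above
grids = psp_grid_sizes(32)
print(grids)
```

Under these assumptions the divisors 1, 2 and 3 give 1×1, 2×2 and 3×3 grids, while the divisor 4 used by the branch named `conv5_3_pool6` gives a 4×4 grid. All four are resized back to 32×32 before the element-wise sum, so the difference does not affect the downstream shapes.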

This corresponds to the following part of the ICNet diagram:

[Image 16 of the original post]

At this point the feature-extraction part of ICNet's top branch is implemented. What remains is the CFF units and the outputs.

CFF units and outputs

        # Excerpt; the first block repeats code shown above
        (self.feed('conv5_3/relu',
                   'conv5_3_pool6_interp',
                   'conv5_3_pool3_interp',
                   'conv5_3_pool2_interp',
                   'conv5_3_pool1_interp')
             .add(name='conv5_3_sum')
             .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='conv5_4_k1')
             .batch_normalization(relu=True, name='conv5_4_k1_bn')
             .interp(factor=2.0, name='conv5_4_interp')
             .zero_padding(paddings=2, name='padding17')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv_sub4')
             .batch_normalization(relu=False, name='conv_sub4_bn'))

        (self.feed('conv3_1/relu')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv3_1_sub2_proj')
             .batch_normalization(relu=False, name='conv3_1_sub2_proj_bn'))

        (self.feed('conv_sub4_bn',
                   'conv3_1_sub2_proj_bn')
             .add(name='sub24_sum')
             .relu(name='sub24_sum/relu')
             .interp(factor=2.0, name='sub24_sum_interp')
             .zero_padding(paddings=2, name='padding18')
             .atrous_conv(3, 3, 128, 2, biased=False, relu=False, name='conv_sub2')
             .batch_normalization(relu=False, name='conv_sub2_bn'))

        (self.feed('data')
             .conv(3, 3, 32, 2, 2, biased=False, padding='SAME', relu=False, name='conv1_sub1')
             .batch_normalization(relu=True, name='conv1_sub1_bn')
             .conv(3, 3, 32, 2, 2, biased=False, padding='SAME', relu=False, name='conv2_sub1')
             .batch_normalization(relu=True, name='conv2_sub1_bn')
             .conv(3, 3, 64, 2, 2, biased=False, padding='SAME', relu=False, name='conv3_sub1')
             .batch_normalization(relu=True, name='conv3_sub1_bn')
             .conv(1, 1, 128, 1, 1, biased=False, relu=False, name='conv3_sub1_proj')
             .batch_normalization(relu=False, name='conv3_sub1_proj_bn'))

        (self.feed('conv_sub2_bn',
                   'conv3_sub1_proj_bn')
             .add(name='sub12_sum')
             .relu(name='sub12_sum/relu')
             .interp(factor=2.0, name='sub12_sum_interp')
             .conv(1, 1, num_classes, 1, 1, biased=True, relu=False, name='conv6_cls'))

        (self.feed('conv5_4_interp')
             .conv(1, 1, num_classes, 1, 1, biased=True, relu=False, name='sub4_out'))

        (self.feed('sub24_sum_interp')
             .conv(1, 1, num_classes, 1, 1, biased=True, relu=False, name='sub24_out'))

The code above can be summarized as follows. The output of the PSP module is conv5_4_k1_bn, of size $(32,32,256)$.

  • The first CFF module fuses features from the top branch into the middle branch
    • Upsample by 2× to $(64,64,256)$, then apply a dilated convolution that reduces the channels, giving $(64,64,128)$
    • Apply a convolution to conv3_1/relu to reduce its channels to $(64,64,128)$, add it element-wise to the output above, and apply ReLU; the result is $(64,64,128)$
    • On the upsampled feature (conv5_4_interp), a further convolution yields the top-branch prediction sub4_out, which corresponds to the 1/16-scale label. Its output is $(64,64,num\_class)$
  • The second CFF module fuses features from the middle branch into the bottom branch
    • Upsample by 2× to $(128,128,128)$, then apply a dilated convolution, giving $(128,128,128)$
    • Pass the input image through 4 convolutions (downsampling conv -> downsampling conv -> downsampling conv -> conv) to obtain $(128,128,128)$, add it element-wise to the output above, and apply ReLU; the result is $(128,128,128)$
    • On the upsampled feature (sub24_sum_interp), a further convolution yields the middle-branch prediction sub24_out, which corresponds to the 1/8-scale label. Its output is $(128,128,num\_class)$
  • Finally, upsample by 2× to $(256,256,128)$ and apply a convolution to obtain the bottom-branch prediction conv6_cls, which corresponds to the 1/4-scale label. Its output is $(256,256,num\_class)$
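The label scales quoted in the list (1/16, 1/8, 1/4) can be reproduced by tracing the three 2× upsampling steps. The sketch below assumes a square 1024×1024 input (inferred from the sizes quoted above) and 19 classes (an illustrative value); `icnet_head_shapes` is a hypothetical helper.

```python
def icnet_head_shapes(input_size, num_classes):
    """Trace the spatial size of each prediction head through the two CFF units."""
    size = input_size // 32  # conv5_4_k1_bn, the PSP module output
    heads = {}
    # Each head is read after one more 2x upsampling:
    # conv5_4_interp -> sub4_out, sub24_sum_interp -> sub24_out,
    # sub12_sum_interp -> conv6_cls
    for name in ("sub4_out", "sub24_out", "conv6_cls"):
        size *= 2
        heads[name] = {
            "shape": (size, size, num_classes),
            "label_scale": input_size // size,  # head is 1/label_scale of the input
        }
    return heads


heads = icnet_head_shapes(1024, num_classes=19)
for name, info in heads.items():
    print(name, info["shape"], "1/%d" % info["label_scale"])
```

Each CFF unit doubles the resolution, so the three supervised outputs sit at successively finer scales, and only the last one (conv6_cls) is used at inference time.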

At this point the ICNet model is essentially fully defined.

The corresponding CFF module structure:

[Image 17 of the original post]

This corresponds to the following part of the ICNet diagram:

[Image 18 of the original post]

Training code

With the model construction covered, we now look at ICNet's loss function and final outputs; the code is in train.py.

The main task is to build the multi-branch loss, which the paper defines as $L=\lambda_1L_1+\lambda_2L_2+\lambda_3L_3$,
where $\lambda_1=0.16$, $\lambda_2=0.4$, $\lambda_3=1$.
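As a quick worked example of the formula, the sketch below combines three made-up branch losses with the weights given above. The loss values are purely illustrative, and `cascade_loss` is a hypothetical helper; the mapping of weights to branches follows the training code (LAMBDA1 for sub4, LAMBDA2 for sub24, LAMBDA3 for sub124).

```python
LAMBDA1, LAMBDA2, LAMBDA3 = 0.16, 0.4, 1.0


def cascade_loss(loss_sub4, loss_sub24, loss_sub124):
    """Weighted sum of the three branch losses (L2 regularization not included)."""
    return LAMBDA1 * loss_sub4 + LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124


# Illustrative values only: 0.16 * 2.0 + 0.4 * 1.5 + 1.0 * 1.0
total = cascade_loss(2.0, 1.5, 1.0)
print("{:.2f}".format(total))
```

The lowest-resolution branch gets the smallest weight, so the full-resolution prediction dominates the gradient while the coarser branches act as auxiliary supervision.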

The main function is the part to look at:

def main():
    """Create the model and start the training."""
    args = get_arguments()
    
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    
    coord = tf.train.Coordinator() # Coordinator for the input queue threads
    
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            ' ',
            args.data_list,
            input_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN,
            coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)
    
    # Build the model
    net = ICNet_BN({'data': image_batch}, is_training=True, num_classes=args.num_classes)
    
    # Fetch the outputs of the top, middle and bottom branches
    sub4_out = net.layers['sub4_out']
    sub24_out = net.layers['sub24_out']
    sub124_out = net.layers['conv6_cls']

    restore_var = tf.global_variables()
    all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]
   
    loss_sub4 = create_loss(sub4_out, label_batch, args.num_classes, args.ignore_label)
    loss_sub24 = create_loss(sub24_out, label_batch, args.num_classes, args.ignore_label)
    loss_sub124 = create_loss(sub124_out, label_batch, args.num_classes, args.ignore_label)
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
    
    # Build the multi-branch loss with L2 regularization
    reduced_loss = LAMBDA1 * loss_sub4 +  LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124 + tf.add_n(l2_losses)

    # Using Poly learning rate policy 
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
    
    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if args.update_mean_var == False:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))
        
    # Set up tf session and initialize variables. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    
    sess.run(init)
    
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

    ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('Restore from pre-trained model...')
        net.load(args.restore_from, sess)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()
        
        feed_dict = {step_ph: step}
        # Run one training step; the fetches are the same whether or not we save
        loss_value, loss1, loss2, loss3, _ = sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op], feed_dict=feed_dict)
        if step % args.save_pred_every == 0:
            save(saver, sess, args.snapshot_dir, step)
        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'.format(step, loss_value, loss1, loss2, loss3, duration))
        
    coord.request_stop()
    coord.join(threads)
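The "poly" learning-rate policy built in the graph above can be written as a plain function to see how the rate decays. The sketch below mirrors the expression `base_lr * (1 - step / num_steps) ** power`; the concrete numbers (base rate 1e-3, 60000 steps, power 0.9) are illustrative assumptions, not values taken from train.py.

```python
def poly_learning_rate(base_lr, step, num_steps, power):
    """Poly decay: starts at base_lr and reaches 0 at step == num_steps."""
    return base_lr * (1.0 - float(step) / num_steps) ** power


# Illustrative settings only
base_lr, num_steps, power = 1e-3, 60000, 0.9
for step in (0, 15000, 30000, 45000, 60000):
    print(step, poly_learning_rate(base_lr, step, num_steps, power))
```

With a power below 1 the decay is slightly slower than linear for most of training, so the rate stays above the straight line between base_lr and 0 until the final step.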


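Inference code

With model training covered, we now look at ICNet inference; the code is in inference.py.

As in the original paper, the processing is simple; see the end of icnet_cityscapes_bnnomerge.prototxt in the code released with the paper.

The result of conv6_cls is taken and resized directly to the prediction size. Here is the main function: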

def main():
    args = get_arguments()
    
    img, filename = load_img(args.img_path)
    shape = img.shape[0:2]

    x = tf.placeholder(dtype=tf.float32, shape=img.shape)
    img_tf = preprocess(x)
    img_tf, n_shape = check_input(img_tf)

    # Create network.
    if args.model[-2:] == 'bn':
        net = ICNet_BN({'data': img_tf}, num_classes=num_classes)
    elif args.model == 'others':
        net = ICNet_BN({'data': img_tf}, num_classes=num_classes)
    else:
        net = ICNet({'data': img_tf}, num_classes=num_classes)
    
    # Take the prediction of the bottom branch
    raw_output = net.layers['conv6_cls']
    
    # Predict: bilinearly resize straight to the target size, crop, then argmax
    raw_output_up = tf.image.resize_bilinear(raw_output, size=n_shape, align_corners=True)
    raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0, shape[0], shape[1])
    raw_output_up = tf.argmax(raw_output_up, dimension=3)
    pred = decode_labels(raw_output_up, shape, num_classes)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()
    
    if args.model == 'train':
        print('Restore from train30k model...')
        net.load(model_train30k, sess)
    elif args.model == 'trainval':
        print('Restore from trainval90k model...')
        net.load(model_trainval90k, sess)
    elif args.model == 'train_bn':
        print('Restore from train30k bnnomerge model...')
        net.load(model_train30k_bn, sess)
    elif args.model == 'trainval_bn':
        print('Restore from trainval90k bnnomerge model...')
        net.load(model_trainval90k_bn, sess)
    else:
        ckpt = tf.train.get_checkpoint_state(snapshot_dir)
        if ckpt and ckpt.model_checkpoint_path:
            loader = tf.train.Saver(var_list=restore_var)
            load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
            load(loader, sess, ckpt.model_checkpoint_path)

    preds = sess.run(pred, feed_dict={x: img})

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    misc.imsave(args.save_dir + filename, preds[0])
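The last two post-processing steps (cropping back to the original image size and taking the per-pixel argmax over classes) can be illustrated without TensorFlow. The sketch below works on nested Python lists and leaves out the bilinear resize; `crop_and_argmax` is a hypothetical helper, and the tiny logits array is made up for illustration.

```python
def crop_and_argmax(logits, height, width):
    """Crop an [H][W][C] logits map to height x width and return class ids."""
    labels = []
    for row in logits[:height]:
        # For each pixel, pick the index of the largest class score
        labels.append([max(range(len(px)), key=px.__getitem__) for px in row[:width]])
    return labels


# A 2x3 map with 2 classes; the third column plays the role of padding
logits = [
    [[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]],
    [[0.3, 0.7], [0.6, 0.4], [0.0, 1.0]],
]
pred = crop_and_argmax(logits, height=2, width=2)
print(pred)
```

Cropping before or after the argmax gives the same labels, since the argmax is computed independently for each pixel.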

