精简CNN模型系列之三:SkipNet

介绍

CNN模型为了追求精度提高层数已经是愈来愈多,可更多的层次带来的精度边际提升却不断减小。或者对某些输入图片而言,真正所需的layers并非那么多,只有一些真正模糊、特征不明显、即使人看上去也较难分辨的图片才需要较多的layers处理最终得到能分别其类别的表达特征。

SkipNet主要是以此假设出发,通过在传统CNN的每个layer(或module)上设置判断其是否需要执行的Gate module来决定是否需要真的执行此层计算,若判断为否则直接将activation feature maps传入到下一层,越过当下层的运算不做。无益这样做可以有效地节省传统CNN模型在部署时进行推理工作所需的时间。

就这样一旦训练好,SkipNet在做图片推理时可根据输入的feature maps不同灵活地决定是否执行某一网络中的层。下图可反映SkipNet这一根本特点。

SkipNet根本思想

SkiptNet

对于每一层操作而言,SkipNet可表示为:xi+1 = GiFi(xi)+(1-Gi)xi。其中xi和Fi(xi)分别表示第ith layer的输入与输出feature maps;Gi ∈{0,1} 为第ith layer的Gate函数。

对于此处的Gate函数,作者实验了两种不同的表示方法。Paper中SkipNet基于的CNN网络为Resnet,其中Gate即可以被独立地添加在各个Residual block上面作为单独的个体,有着不同的参数即Feed-forward Gate;还可以所有的Residual blocks复用一个Gate module即Recurrent Gate。其不同之处可从下图中看出。

SkipNet中两种不同的Gate函数选择

Gate module设计

作者在论文中共尝试了三种不同的Gate module设计,它们对计算与accuracy的考量略有不同。

FFGate-I: MaxPool(2x2) -> Conv(3x3, 1) -> Conv(3x3, 2) -> AvgPool -> FC,整体计算量约为Residual block的19%,在论文中主要用于较浅的一些网络(层数小于100);
FFGate-II: Conv(3x3, 2) -> AvgPool -> FC,整体计算量约为Residual block的12.5%,主要用于较深的一些网络(层数大于100);
RNNGate: AvgPool -> Conv(1x1) -> LSTM(10 hidden units) -> FC,整体计算量约为Residual block的0.04%,是论文中首选的Gate函数。在深层次网络中它相对于Feed-forward Gate有较大的性能与分类精度优势,只是在较浅的层次上它精度略低,但计算开销仍有较大优势。

下图为以上三种Gate module的概况描述。

三种具体的Gate_module设计

使用Hybrid RL的Skipping policy学习

对于上节所介绍的Gate函数可理解为是这么一种决策:Π(xi,i) = P(Gi(xi) = gi),(其中gi∈{0,1},分别表示执行还是略过第ith层执行的两种离散决策)。

这样对于有N层的CNN来说,我们在forward时需要决定下如此一个输入为x的决策序列:g = [g1,....,gN] ˜ Π(Fθ)。在这里Fθ = [Fθ1,....,FθN]表示CNN网络中N个layers的计算。

而整体的目标函数则可表示如下:

Skip_learning中使用Hybrid_RL时的整体目标函数

其中Ri = (1-gi)Ci表示的是每个Gate module所节省的计算,亦为它的激励函数。因为paper中用的是Resnet,故假定所有的Ci相同,设为1。然后α 则为CNN分类准确率与计算节省之间的平衡系数。可以看出这里的目标函数设计同时考虑了模型分类精度与计算效率并力图在其中寻找平衡。

下式为具体计算时的梯度计算公式。可以看出它主要由两部分组成,第一部分表示的是学习分类精度的supervised loss,第二部分则是要接合RL最终学习出来的反映计算节省的Skip learning policy。

Skip_learning中使用Hybrid_RL时的梯度计算

下图为使用Hybrid RL的具体算法概述。

Hybrid_RL_learning算法

实验结果

下图为SkipNet在各大数据集上得到的分类精度结果。

在各大数据集上SkipNet得到的分类精度

下表中反映了不同SkipNet配置与训练方法在达到与原生ResNet相似精度的情况下换来的计算节省。

不同SkipNet配置在达到相似精度情况下得到的计算节省

代码分析

如下为FFGate-I的设计实现,其它Gate module的写法并无太多不同。

# Feedforward-Gate (FFGate-I)
class FeedforwardGateI(nn.Module):
    """ Use Max Pooling First and then apply to multiple 2 conv layers.
    The first conv has stride = 1 and second has stride = 2"""
    def __init__(self, pool_size=5, channel=10):
        super(FeedforwardGateI, self).__init__()
        self.pool_size = pool_size
        self.channel = channel

        self.maxpool = nn.MaxPool2d(2)
        self.conv1 = conv3x3(channel, channel)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu1 = nn.ReLU(inplace=True)

        # adding another conv layer
        self.conv2 = conv3x3(channel, channel, stride=2)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu2 = nn.ReLU(inplace=True)

        pool_size = math.floor(pool_size/2)  # for max pooling
        pool_size = math.floor(pool_size/2 + 0.5)  # for conv stride = 2

        self.avg_layer = nn.AvgPool2d(pool_size)
        self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2,
                                      kernel_size=1, stride=1)
        self.prob_layer = nn.Softmax()
        self.logprob = nn.LogSoftmax()

    def forward(self, x):
        x = self.maxpool(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.avg_layer(x)
        x = self.linear_layer(x).squeeze()
        softmax = self.prob_layer(x)
        logprob = self.logprob(x)

        # discretize output in forward pass.
        # use softmax gradients in backward pass
        x = (softmax[:, 1] > 0.5).float().detach() - \
            softmax[:, 1].detach() + softmax[:, 1]

        x = x.view(x.size(0), 1, 1, 1)
        return x, logprob

下面这个class里面则具体实现了如何将Gate module与某一CNN网络结合起来从而实现相关的SkipNet。

class ResNetFeedForwardRL(nn.Module):
    """Adding gating module on every basic block"""

    def __init__(self, block, layers, num_classes=10,
                 gate_type='ffgate1', **kwargs):
        self.inplanes = 16
        super(ResNetFeedForwardRL, self).__init__()

        self.num_layers = layers
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.gate_instances = []
        self.gate_type = gate_type
        self._make_group(block, 16, layers[0], group_id=1,
                         gate_type=gate_type, pool_size=32)
        self._make_group(block, 32, layers[1], group_id=2,
                         gate_type=gate_type, pool_size=16)
        self._make_group(block, 64, layers[2], group_id=3,
                         gate_type=gate_type, pool_size=8)

        # remove the last gate instance, (not optimized)
        del self.gate_instances[-1]

        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self.softmax = nn.Softmax()
        self.saved_actions = []
        self.rewards = []

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(0) * m.weight.size(1)
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def _make_group(self, block, planes, layers, group_id=1,
                    gate_type='fisher', pool_size=16):
        """ Create the whole group"""
        for i in range(layers):
            if group_id > 1 and i == 0:
                stride = 2
            else:
                stride = 1

            meta = self._make_layer_v2(block, planes, stride=stride,
                                       gate_type=gate_type,
                                       pool_size=pool_size)

            setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0])
            setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1])
            setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2])

            # add into gate instance collection
            self.gate_instances.append(meta[2])

    def _make_layer_v2(self, block, planes, stride=1,
                       gate_type='fisher', pool_size=16):
        """ create one block and optional a gate module """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),

            )
        layer = block(self.inplanes, planes, stride, downsample)
        self.inplanes = planes * block.expansion

        if gate_type == 'ffgate1':
            gate_layer = RLFeedforwardGateI(pool_size=pool_size,
                                            channel=planes*block.expansion)
        elif gate_type == 'ffgate2':
            gate_layer = RLFeedforwardGateII(pool_size=pool_size,
                                             channel=planes*block.expansion)
        else:
            gate_layer = None

        if downsample:
            return downsample, layer, gate_layer
        else:
            return None, layer, gate_layer

    def repackage_vars(self):
        self.saved_actions = repackage_hidden(self.saved_actions)

    def forward(self, x, reinforce=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        masks = []
        gprobs = []
        # must pass through the first layer in first group
        x = getattr(self, 'group1_layer0')(x)
        # gate takes the output of the current layer
        mask, gprob = getattr(self, 'group1_gate0')(x)
        gprobs.append(gprob)
        masks.append(mask.squeeze())
        prev = x  # input of next layer

        for g in range(3):
            for i in range(0 + int(g == 0), self.num_layers[g]):
                if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None:
                    prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev)
                x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x)
                # new mask is taking the current output
                prev = x = mask.expand_as(x) * x \
                           + (1 - mask).expand_as(prev) * prev
                mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x)
                gprobs.append(gprob)
                masks.append(mask.squeeze())

        del masks[-1]

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # collect all actions
        for inst in self.gate_instances:
            self.saved_actions.append(inst.saved_action)

        if reinforce:  # for pure RL
            softmax = self.softmax(x)
            action = softmax.multinomial()
            self.saved_actions.append(action)

        return x, masks, gprobs

参考文献

  • SkipNet: Learning Dynamic Routing in Convolutional Networks, Xin-Wang, 2018
  • https://github.com/ucbdrive/skipnet

你可能感兴趣的:(精简CNN模型系列之三:SkipNet)