Backbone之Senet介绍

Backbone之Senet介绍

论文:https://arxiv.org/pdf/1709.01507.pdf
代码:https://github.com/moskomule/senet.pytorch

Senet是Momenta在2017的cvpr上发布的,并且获得了ImageNet当年也应该是最后一届冠军,senet相比较于resnet通过se分支增加了类似nlp的Attention 机制,来更加关注channel之间的关系,因为对于一张图片来说,有RGB三个通道,可能只有一个通道有用,其他通道价值就很低,在resnet每个block上,比如说输出512个channel,可能就几十个channel有利于结果,那么其他的可能就是干扰,通俗的来说,就是给重要的通道一个更大的权重,同时抑制不必要的通道。

如下图所示,Senet的核心就是Squeeze和Excitation操作,换句话说,senet的核心就是这个se分支,简单,有效,同时没有明显的参数增加,其实我觉得好的网络结构就应该这样,大道至简。
Backbone之Senet介绍_第1张图片
Backbone之Senet介绍_第2张图片

Backbone之Senet介绍_第3张图片

一个是inception网络上添加se分支 一个是在resnet网络上添加se分支
可以看出在backbone上添加一个se分支其实是一件相当容易的事情。

总结一下就是,原来这个block输出的是HxWxC的,然后通过global池化到Cx1x1,然后进入第一个fc层,输出1/16Cx1x1,然后进入第二个fc层输出Cx1x1,这样做的好处是比直接用一个fc层具有更多的非线性。可以更好地拟合通道之间复杂的相关性,同时减少计算。最后通过一个sigmod层,因为这样就可以表达出通道之间的相关性(全部加起来等于1),输出Cx1x1直接与block输出结果按元素乘

下面我们来看代码的实现:
首先是se模块:

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

正如同论文的SE分支一样

Backbone之Senet介绍_第4张图片
其实就一个global pooling+fc+relu+fc+sigmoid层,简单易实现这里x是上一个block的输出,其实se分支添加进backbone的主干方法有很多,根据残差块和SE分支位置的差异,作者其实做了相当多的对比试验

Backbone之Senet介绍_第5张图片
Backbone之Senet介绍_第6张图片
Backbone之Senet介绍_第7张图片
Backbone之Senet介绍_第8张图片
那么主要是有这四种,结论是这样的:

Backbone之Senet介绍_第9张图片
同时在SE分支第一个FC层的衰减上r上,作者也做了相应的对比试验:

Backbone之Senet介绍_第10张图片
很明显可以看出,随着r的增加,参数量逐渐减少,但是分类精度也随之减低,
为了均衡参数/精度,作者选取的是16

Backbone之Senet介绍_第11张图片
作者同时对比了se分支激活函数,发现sigmoid函数效果更好

Backbone之Senet介绍_第12张图片
在所有阶段添加SE分支比在某一个阶段添加要好

最后效果(这个是论文的)
Backbone之Senet介绍_第13张图片
在我们实际使用发现,se分支确实有效,添加se分支之后backbone的特征提取能力有明显提升,同时耗时没有明显增加

你可能感兴趣的:(机器学习,神经网络,pytorch,深度学习)