Implementing SENet in MXNet (partial code)

Squeeze-and-Excitation Networks (SENet) won the ImageNet 2017 classification competition.
Paper: https://arxiv.org/abs/1709.01507
This post gives a brief introduction to SENet and provides an MXNet implementation of SE-ResNet (mainly using the Gluon API).

In SENet, Squeeze and Excitation are the two key operations, illustrated below:


[Figure: schematic of the Squeeze and Excitation operations]

Step 1: Squeeze compresses the features along the spatial dimensions, i.e. Global Average Pooling, producing one descriptor per channel.

Step 2: Excitation uses a Sigmoid Function to generate a weight for each feature channel; the weights capture the dependencies between channels.

Step 3: Reweight multiplies the weights produced by Excitation onto the CNN feature maps channel by channel, recalibrating the original features along the channel dimension.
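
The three steps can be seen directly at the tensor level. Below is a minimal shape-level sketch, not the actual SE module: the sigmoid gate here merely stands in for the small fully connected network used in the paper, and the tensor sizes are toy values.

from mxnet import nd

# Toy feature map: batch=2, channels=4, spatial size 8x8
X = nd.random.uniform(shape=(2, 4, 8, 8))

# Step 1 (Squeeze): global average pooling, one scalar per channel
z = nd.mean(X, axis=(2, 3))                      # shape (2, 4)

# Step 2 (Excitation): stand-in gate; SENet uses a two-layer FC network here
w = nd.sigmoid(z)                                # shape (2, 4), values in (0, 1)

# Step 3 (Reweight): broadcast the per-channel weights over the feature map
Y = nd.broadcast_mul(X, w.reshape((2, 4, 1, 1)))
print(Y.shape)                                   # (2, 4, 8, 8)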

The SE module can easily be embedded into any neural network. Below is the structure of SE-ResNet:


[Figure: SE-ResNet network structure]

Now for the code.
This is the original Residual Block, included here for reference:

from mxnet.gluon import nn

class Residual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # 1x1 conv so the shortcut matches the main path in channels/stride
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)
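
A quick sanity check of the block above; the input sizes are my own toy values:

from mxnet import nd

blk = Residual(3)
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 6, 6))
print(blk(X).shape)    # (4, 3, 6, 6): shape preserved

blk = Residual(6, use_1x1conv=True, strides=2)
blk.initialize()
print(blk(X).shape)    # (4, 6, 3, 3): channels changed, spatial size halved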

Now the key part, the SE module. For easier understanding, we write the Squeeze and Excitation steps as a standalone block:

def Attention(num_channels):
    # Note: the paper shrinks the first Dense layer by a reduction ratio
    # (e.g. r=16); here both layers keep num_channels for simplicity.
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(
            nn.GlobalAvgPool2D(),        # Squeeze: one descriptor per channel
            nn.Dense(num_channels),      # Excitation: two FC layers...
            nn.Activation('relu'),
            nn.Dense(num_channels),
            nn.Activation('sigmoid')     # ...ending in a sigmoid gate in (0, 1)
        )
    return net
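
Called on a feature map, the block returns one weight per channel per sample; a toy check with made-up sizes:

from mxnet import nd

att = Attention(64)
att.initialize()
X = nd.random.uniform(shape=(2, 64, 8, 8))
w = att(X)
print(w.shape)                                   # (2, 64)
print(w.min().asscalar(), w.max().asscalar())    # all values lie in (0, 1)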

Next, embed the SE module into the Residual Block and apply it with a broadcast multiply:

class SEResidual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(SEResidual, self).__init__(**kwargs)
        self.num_channels = num_channels  # kept to reshape the attention weights
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()
        self.weight = Attention(num_channels)

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        W = Y
        for layer in self.weight:  # W is the channel-wise attention weight
            W = layer(W)
        if self.conv3:
            X = self.conv3(X)
        # Reweight: broadcast the (batch, num_channels) weights over the feature map
        Y = F.broadcast_mul(Y, F.reshape(W, shape=(-1, self.num_channels, 1, 1)))
        return F.relu(Y + X)
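
The SE block is a drop-in replacement for the plain Residual block; a quick toy check:

from mxnet import nd

blk = SEResidual(3)
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 6, 6))
print(blk(X).shape)    # (4, 3, 6, 6): same interface and output shape as Residual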

Finally, just stack SE-Residual blocks like building blocks to assemble the full network; a minimal sketch follows.
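
For illustration only: a ResNet-18-style layout. The se_resnet_block helper, the stage widths, and the 10-way output are my own choices for this sketch, not from the original post.

from mxnet.gluon import nn

def se_resnet_block(num_channels, num_residuals, first_block=False):
    # Hypothetical helper: one stage of SEResidual blocks; except in the first
    # stage, the leading block halves the spatial size and changes the channels.
    blk = nn.HybridSequential()
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.add(SEResidual(num_channels, use_1x1conv=True, strides=2))
        else:
            blk.add(SEResidual(num_channels))
    return blk

net = nn.HybridSequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
        nn.BatchNorm(), nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1),
        se_resnet_block(64, 2, first_block=True),
        se_resnet_block(128, 2),
        se_resnet_block(256, 2),
        se_resnet_block(512, 2),
        nn.GlobalAvgPool2D(),
        nn.Dense(10))
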
That's it~
