轻量型注意力模块:ULSAM

ULSAM: Ultra-Lightweight Subspace Attention Module for Compact Convolutional Neural Networks
论文地址

作者提出了一种新的用于紧凑网络神经网络的注意力块(ULSAM),它可以学习每个特征子空间的个体注意力映射,并能够在多尺度、多频率特征学习的同时高效地学习跨信道信息。
轻量型注意力模块:ULSAM_第1张图片
主要思想:将提取的特征分成g组,对每组的子特征(论文中称问subspace)进行空间上的重新校准,最后,把g组特征concatenate到一起。具体做法看下面代码。首先用1×1 depthwise conv对每组特征提取channel为nin的新特征,然后maxpool,再pointwise conv成channel为1的attention map,最后利用softmax对attention map在H维缩放,确保attention map的权重和为1。

class SubSpace(nn.Module):
    """
    Subspace class.
    ...
    Attributes
    ----------
    nin : int
        number of input feature volume.
    Methods
    -------
    __init__(nin)
        initialize method.
    forward(x)
        forward pass.
    """

    def __init__(self, nin):
        super(SubSpace, self).__init__()
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        out = self.conv_dws(x)
        out = self.bn_dws(out)
        out = self.relu_dws(out)

        out = self.maxpool(out)

        out = self.conv_point(out)
        out = self.bn_point(out)
        out = self.relu_point(out)

        m, n, p, q = out.shape
        out = self.softmax(out.view(m, n, -1))
        out = out.view(m, n, p, q)

        out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])

        out = torch.mul(out, x)

        out = out + x

        return out

分析一下整个attention map的计算复杂度,主要就是dw conv中的nin×h×w×1×1跟pw conv中nin×h×w×1,同时考虑到原始特征由分组而来,计算量缺失很小。

轻量型注意力模块:ULSAM_第2张图片
从上图结果来看,(POS11:1表示在第11个layer后在使用ULSAM),分组数g=4时可以获得较好的结果。虽然,整体增益不大,但考虑在整体计算复杂度基本不变,此方法确实有些意思。但有个问题,该方法增加了add,mul等elment-wise操作,这也会增加计算负担,并且在flops和param中无法体现。

你可能感兴趣的:(pytorch,深度学习,深度学习,pytorch,轻量型模型)