paper:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond
official implementaion:https://github.com/xvjiarui/GCNet
Third party implementation:https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/context_block.py
通过捕获long-range dependency提取全局信息,对各种视觉任务都是很有帮助的。Non-local Network(介绍见https://blog.csdn.net/ooooocj/article/details/124573078)通过自注意力机制来解决这个问题。对于每个查询位置(query position),non-local network首先计算该位置和所有位置之间一个两两成对的关系,得到一个attention map。然后对attention map所有位置的权重加权求和得到汇总特征,每一个查询位置都得到一个汇总特征,将汇总特征与原始特征相加得到最终输出。
对于某个query position,non-local network计算的另一个位置与该位置的关系即一个权重值表示这个位置对query位置的重要程度。本文可视化attention map发现,不同的query位置其对应的attention map几乎一样,如下图所示
non-local block可以表示为下式
其中 \(i\) 是query position的索引,\(j\) 遍历所有位置,\(f(\mathbf{x}_{i},\mathbf{x}_{j})\) 表示位置 \(i\) 和 \(j\) 之间的关系,\(\mathcal{C}(\mathbf{x})\) 是归一化因子,\(W_{z}\) 和 \(W_{v}\) 是线性变换矩阵例如 \(1\times1\) 卷积。non-local block有多种不同的实例化方法,例如Gaussian、Embedded Gaussian、Dot product、Concat,下图(a)是Embedded Gaussian的结构。
由于需要计算每个query位置的attention map,因此non-local block的时间和空间复杂度都是所有位置的平方关系。
下图是作者从COCO数据集中随机挑选的6张图片,并可视化出3个不同的query position即图中的红点与对应的query-specific attention map,可以看出对于不同的查询位置,它们的attention map几乎是相同的。
为了进一步验证这一观察结果,作者又分析了不同的查询位置与全局上下文之间的距离。结果如下表所示。其中计算了三种向量之间的余弦距离cosine distance,分别为non-local block的输入、输出、以及查询的注意力图,对应表中的input、output、att。
从表中可以看出,input列的余弦距离比较大表明non-local的输入特征可以在不同的位置进行区分,但output列的余弦距离非常小,表明non-local block建模的全局上下文特征对于不同的query position几乎是相同的。attention map上的距离非常小,也验证了可视化的观察结果。
尽管non-local block是打算针对每个位置计算全局上下文的,但训练后的全局上下文实际上是独立于查询位置的。因为没有必要为每个查询位置单独计算query-specific全局上下文。
本文通过观察发现non-local block针对每个query position计算的attention map最终结果是独立于查询位置的,那么就没有必要针对每个查询位置计算了,因此提出计算一个通用的attention map并应用于输入feature map上的所有位置,大大减少了计算量的同时又没有导致性能的降低。此外,结合SE block,设计了一个新的Global Context (GC) block,既轻量又可以有效地建模全局上下文。GC Block结合了Non-local block和SE block的优点,基于GC Block设计的GCNet在多个任务上均超过了NLNet和SENet。
作者舍去了式(1)中的 \(W_{z}\) 即图(3)(a)中的query分支,得到下式
这里采用的是最常用的Embedded Gaussiian的实例化方式。简化后的non-local block如图(3)(b)所示。
为了进一步减少计算量,作者应用分配率将 \(W_{v}\) 移到attention pool的外面,如下
这里简化后的non-local block如图4(b)所示。1x1卷积 \(W_{v}\) 的FLOPs从 \(\mathcal{O}(HWC^{2})\) 减小到 \(\mathcal{O}(C^{2})\)。
到目前为止简化的NL block中参数量最大的部分在transform module即图4(b)中的Transform部分,这里是一个1x1卷积但参数量为 \(C\cdot C\),当把Nl block应用到较深的层例如resnet中的 \(res_{5}\) 时,CxC=2028x2048,占据了整个block大部分的计算量。为了进一步减少计算量,作者借鉴了SE block的思想如图4(c),将图4(b)中的transform module换成了图4(d)中的bottleneck transform module,其中 \(r\) 是reduction ratio,这样参数量就从C·C变成了2·C·C/r,默认情况下r=16,因此参数量就减少为1/8。
下表baseline是backbone为ResNet-50的Mask R-CNN在COCO数据集上的目标检测和实例分割的结果。将1个non-local block(NL)、1个simplified non-local block(SNL)、1个global context block(GC)插入到c4的最后一个residual block前,可以看出GC block获得的相似的性能但参数量更小。将GC block添加到所有的residual block中在参数量相似的情况下得到了更高的性能。
这里的代码是mmcv中的实现。
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from torch import nn
from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
constant_init(m, val=0)
@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
"""ContextBlock module in GCNet.
See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
(https://arxiv.org/abs/1904.11492) for details.
Args:
in_channels (int): Channels of the input feature map.
ratio (float): Ratio of channels of transform bottleneck
pooling_type (str): Pooling method for context modeling.
Options are 'att' and 'avg', stand for attention pooling and
average pooling respectively. Default: 'att'.
fusion_types (Sequence[str]): Fusion method for feature fusion,
Options are 'channels_add', 'channel_mul', stand for channelwise
addition and multiplication respectively. Default: ('channel_add',)
"""
_abbr_ = 'context_block'
def __init__(self,
in_channels: int,
ratio: float,
pooling_type: str = 'att',
fusion_types: tuple = ('channel_add', )):
super().__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ['channel_add', 'channel_mul']
assert all([f in valid_fusion_types for f in fusion_types])
assert len(fusion_types) > 0, 'at least one fusion should be used'
self.in_channels = in_channels
self.ratio = ratio
self.planes = int(in_channels * ratio)
self.pooling_type = pooling_type
self.fusion_types = fusion_types
if pooling_type == 'att':
self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if 'channel_add' in fusion_types:
self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
else:
self.channel_add_conv = None
if 'channel_mul' in fusion_types:
self.channel_mul_conv = nn.Sequential(
nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
else:
self.channel_mul_conv = None
self.reset_parameters()
def reset_parameters(self):
if self.pooling_type == 'att':
kaiming_init(self.conv_mask, mode='fan_in')
self.conv_mask.inited = True
if self.channel_add_conv is not None:
last_zero_init(self.channel_add_conv)
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)
def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
context_mask = self.conv_mask(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = self.softmax(context_mask)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x: torch.Tensor) -> torch.Tensor:
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.channel_mul_conv is not None:
# [N, C, 1, 1]
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
if self.channel_add_conv is not None:
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
return out