Paper: https://arxiv.org/pdf/2209.08575v1.pdf
Code: GitHub - Visual-Attention-Network/SegNeXt: Official Pytorch implementations for "SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation" (NeurIPS 2022)
1. Paper overview
(1) Common approaches to improving semantic segmentation:
- A stronger backbone network
- Multi-scale information fusion
- Spatial attention
- Channel attention
- Low computational complexity
Encoder (backbone): general-purpose classification networks (ResNet / ResNeXt / DenseNet) and segmentation-oriented encoders (Res2Net / HRNet / SETR / SegFormer / HRFormer / MPViT / DPT)
Decoder: multi-scale receptive fields (FPN-style?) [94,7,78], collecting multi-scale semantics [64,80,8], enlarging the receptive field [5,4,62], strengthening edge features [95,2,16,42,90], capturing global context [19,34,89,40,23,26,91]
(2) Attention mechanisms: an adaptive selection process.
<Spatial attention> attends to informative spatial regions
<**Channel attention**> attends to the important channels (i.e., which objects matter)
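A toy illustration of the two axes attention can operate over (my own sketch, not from the paper; the `sigmoid(mean)` gates stand in for whatever learned scoring function a real module uses):

```python
import torch

x = torch.randn(1, 4, 8, 8)  # (batch, channels, H, W)

# Spatial attention: one weight per spatial location, shared across channels.
spatial_w = torch.sigmoid(x.mean(dim=1, keepdim=True))       # (1, 1, 8, 8)

# Channel attention: one weight per channel, shared across locations.
channel_w = torch.sigmoid(x.mean(dim=(2, 3), keepdim=True))  # (1, 4, 1, 1)

# Both broadcast back over the input without changing its shape.
print((x * spatial_w).shape, (x * channel_w).shape)
```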
SegNeXt:
- Evokes **spatial attention** through multi-scale convolutional features
- An encoder built from simple, cheap convolutions can still outperform ViT-based ones, especially at handling object details, while requiring less computation.
- MSCA has three parts: a depth-wise convolution to aggregate local information, multi-branch depth-wise strip convolutions to capture multi-scale context, and a 1×1 convolution to model relationships between channels.
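A quick back-of-the-envelope check (my own sketch, not from the paper's code) of why the strip convolutions are cheap: a depth-wise (1, k) + (k, 1) pair costs 2k weights per channel, versus k² for a full k×k depth-wise kernel:

```python
def dw_strip_params(k, channels):
    """Weights in a depth-wise (1,k) + (k,1) strip-conv pair (biases ignored)."""
    return 2 * k * channels

def dw_square_params(k, channels):
    """Weights in a single depth-wise k x k conv (biases ignored)."""
    return k * k * channels

# Compare at the three kernel sizes MSCA uses, for 64 channels.
for k in (7, 11, 21):
    print(k, dw_strip_params(k, 64), dw_square_params(k, 64))
```

At k = 21 the strip pair needs 42 weights per channel where a square kernel would need 441, which is why MSCA can afford such large receptive fields.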
2. Code
(1) Building the MSCA module
Local features are first extracted with a 5×5 depth-wise convolution:
```python
import torch.nn as nn
from mmcv.runner import BaseModule


class AttentionModule(BaseModule):
    def __init__(self, dim):
        super().__init__()
        # 5x5 depth-wise conv: aggregates local information
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # three branches of depth-wise strip convolutions (scales 7, 11, 21)
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        # 1x1 conv: models relationships between channels
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)          # 5x5 local feature extraction
        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)
        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)
        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)
        attn = attn + attn_0 + attn_1 + attn_2  # fuse the multi-scale branches
        attn = self.conv3(attn)       # 1x1 conv mixes channels
        return attn * u               # reweight the input with the attention map
```
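Note that every branch is shape-preserving: each (1, k) / (k, 1) pair pads by (k-1)/2 along its long axis. A minimal sanity check for one strip-conv pair (a standalone sketch mirroring `conv0_1` / `conv0_2`, so it runs without mmcv):

```python
import torch
import torch.nn as nn

dim = 8
x = torch.randn(2, dim, 32, 32)  # (batch, channels, H, W)

# depth-wise strip-conv pair at scale k=7, as in conv0_1 / conv0_2 above
conv_1xk = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
conv_kx1 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)

y = conv_kx1(conv_1xk(x))
print(y.shape)  # same spatial size as the input
```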
(2) The spatial attention block (the building block of the MSCAN backbone)
```python
class SpatialAttention(BaseModule):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut  # residual connection
        return x
```
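The block follows a projection → gating → projection pattern with a residual connection, so it is also shape-preserving. A runnable toy version (my own simplification: the gating unit here is a 3×3 depth-wise conv standing in for the full MSCA, so the sketch runs without mmcv):

```python
import torch
import torch.nn as nn


class TinySpatialAttention(nn.Module):
    """Same skeleton as SpatialAttention; the gating unit is a stand-in."""
    def __init__(self, d_model):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        # stand-in for AttentionModule: any shape-preserving gate fits here
        self.gate = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shortcut = x
        x = self.proj_2(self.gate(self.activation(self.proj_1(x))))
        return x + shortcut  # residual connection


block = TinySpatialAttention(16)
out = block(torch.randn(1, 16, 64, 64))
print(out.shape)  # input shape is preserved
```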