[Paper Series] SegNeXt

Paper: https://arxiv.org/pdf/2209.08575v1.pdf

Code: https://github.com/Visual-Attention-Network/SegNeXt (official PyTorch implementation of "SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation", NeurIPS 2022)

1. Main contents of the paper

(1) Ways to improve semantic segmentation:

- optimizing the backbone network
- multi-scale information fusion
- spatial attention
- channel attention
- low computational complexity

Encoder (backbone):

- general classification networks: ResNet / ResNeXt / DenseNet
- segmentation-specific encoders: Res2Net / HRNet / SETR / SegFormer / HRFormer / MPViT / DPT

Decoder: multi-scale receptive fields (FPN?) [94,7,78], collecting multi-scale semantics [64,80,8], enlarging the receptive field [5,4,62], strengthening edge features [95,2,16,42,90], capturing global context [19,34,89,40,23,26,91]

(2) Attention mechanism: an adaptive selection process.

<Spatial attention> focuses on the important spatial regions.

<**Channel attention**> focuses on the important channels (i.e., which feature maps matter).
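For contrast with the convolutional spatial attention used later, here is a minimal channel-attention sketch in the squeeze-and-excitation (SE) style; the SE structure and the reduction ratio are assumptions for illustration only and are not part of SegNeXt.

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    """SE-style channel attention (illustrative sketch, not from the paper)."""
    def __init__(self, dim, reduction=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),               # squeeze: one value per channel
            nn.Conv2d(dim, dim // reduction, 1),   # excitation: bottleneck 1x1 conv
            nn.ReLU(inplace=True),
            nn.Conv2d(dim // reduction, dim, 1),
            nn.Sigmoid(),                          # per-channel weights in (0, 1)
        )

    def forward(self, x):
        return x * self.fc(x)                      # reweight channels, keep spatial layout

x = torch.randn(1, 64, 32, 32)
print(ChannelAttention(64)(x).shape)               # torch.Size([1, 64, 32, 32])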

SegNeXt:

- evokes **spatial attention** through multi-scale convolutional features
- shows that an encoder built from simple, cheap convolutions can still outperform ViT-based encoders, especially when handling object details, while requiring less computation.

Figure (MSCA module, from the paper): a depth-wise convolution aggregates local information, multi-branch depth-wise strip convolutions capture multi-scale context, and a 1×1 convolution models the relationships between different channels.
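A quick way to see why strip convolutions are used: a pair of depth-wise 1×k and k×1 convolutions covers a k×k receptive field with roughly 2k weights per channel instead of k². A small PyTorch sketch (the channel count and printed numbers are for illustration only):

import torch.nn as nn

dim = 64
# a pair of depth-wise strip convolutions: 1x7 followed by 7x1
strip = nn.Sequential(
    nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim),
    nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim),
)
# the dense depth-wise 7x7 convolution it approximates
square = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(strip), count(square))  # 1024 vs 3200 parameters (weights + biases, dim=64)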

2. Code

(1) Building the MSCA

First, a 5×5 depth-wise convolution extracts local features; the depth-wise strip-convolution branches and the final 1×1 convolution then produce the attention map:

import torch.nn as nn
from mmcv.runner import BaseModule  # as in the official repo; a plain nn.Module would also work


class AttentionModule(BaseModule):
    """Multi-Scale Convolutional Attention (MSCA)."""

    def __init__(self, dim):
        super().__init__()
        # 5x5 depth-wise convolution to aggregate local information
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)

        # three branches of depth-wise strip convolutions (7, 11, 21)
        # that capture multi-scale context at low cost
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)

        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)

        self.conv2_1 = nn.Conv2d(
            dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(
            dim, dim, (21, 1), padding=(10, 0), groups=dim)

        # 1x1 convolution to model relationships between channels
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)            # 5x5 local feature extraction

        attn_0 = self.conv0_1(attn)     # 7x7 strip branch
        attn_0 = self.conv0_2(attn_0)

        attn_1 = self.conv1_1(attn)     # 11x11 strip branch
        attn_1 = self.conv1_2(attn_1)

        attn_2 = self.conv2_1(attn)     # 21x21 strip branch
        attn_2 = self.conv2_2(attn_2)

        attn = attn + attn_0 + attn_1 + attn_2  # fuse multi-scale features

        attn = self.conv3(attn)         # 1x1 conv mixes channel information

        return attn * u                 # use the result as attention weights on the input
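A minimal usage sketch (assuming the AttentionModule above and a dummy input): the module preserves the spatial resolution and channel count, and its output is the input reweighted by the multi-scale attention map.

import torch

attn = AttentionModule(dim=64)
x = torch.randn(2, 64, 32, 32)    # (batch, channels, height, width)
out = attn(x)
print(out.shape)                  # torch.Size([2, 64, 32, 32]) -- same shape as the input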

(2) SpatialAttention (the attention block used inside MSCAN)

class SpatialAttention(BaseModule):
    """Attention block of MSCAN: 1x1 proj -> GELU -> MSCA -> 1x1 proj, with a residual."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)          # 1x1 input projection
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)   # the MSCA module above
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)          # 1x1 output projection

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut                                       # residual connection
        return x
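A similar usage sketch (assuming the classes above): SpatialAttention is also shape-preserving, which is what lets MSCAN stack these blocks (together with normalization and an MLP in the full network) without changing the feature-map size.

import torch

block = SpatialAttention(d_model=64)
feat = torch.randn(2, 64, 56, 56)
out = block(feat)
print(out.shape)                  # torch.Size([2, 64, 56, 56])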
