Segformer论文研读

论文详细讲解如下:

【论文笔记】Segformer论文阅读笔记_嘟嘟太菜了的博客-CSDN博客

 论文笔记——Segformer: 一种基于Transformer的语义分割方法 - 知乎

【图1：SegFormer 整体结构图】

上图为 SegFormer 的整体结构。

----------------------------------------------------------------------------------------------------------------------

 SegFormer = 骨干网络 backbone（MiT，编码器）+ 解码头 SegFormerHead（全 MLP 解码器）

1、backbone(mit):

from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import DropPath

"""
Attention模块计算注意力
例如:
输入:x:[1,3136,32]
输出:x:[1,3136,32]
"""
class Attention(nn.Module):
    """Efficient multi-head self-attention with spatial reduction (SegFormer).

    Queries are computed from every token. When ``sr_ratio > 1``, keys and
    values are computed from a copy of the token grid that was shrunk by a
    non-overlapping ``sr_ratio x sr_ratio`` convolution and LayerNorm. This
    cuts the key/value sequence length by roughly ``sr_ratio ** 2``.

    Args:
        dim: channel width C of the tokens.
        head: number of attention heads (C must be divisible by it).
        sr_ratio: spatial-reduction factor for keys/values; 1 disables it.

    Shapes:
        x: (B, N, C) with N == H * W (required when sr_ratio > 1, because the
        tokens are folded back into an H x W grid). Output: (B, N, C).
        Example: x (1, 3136, 32), H = W = 56, sr_ratio = 8 -> K/V length 49.
    """

    def __init__(self, dim, head, sr_ratio):
        super().__init__()
        self.head = head
        self.sr_ratio = sr_ratio
        head_dim = dim // head
        # 1 / sqrt(d_head) temperature applied to the q @ k^T logits.
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)   # keys and values from one projection
        self.proj = nn.Linear(dim, dim)

        if sr_ratio > 1:
            # kernel == stride: each sr_ratio x sr_ratio patch collapses to one token.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.head

        # (B, N, C) -> (B, heads, N, C / heads)
        query = self.q(x).reshape(batch, tokens, self.head, head_dim).permute(0, 2, 1, 3)

        # Keys/values are taken from the (optionally) spatially reduced sequence.
        source = x
        if self.sr_ratio > 1:
            grid = x.permute(0, 2, 1).reshape(batch, channels, H, W)
            reduced = self.sr(grid).reshape(batch, channels, -1).permute(0, 2, 1)
            source = self.norm(reduced)

        # (B, M, 2C) -> (2, B, heads, M, C / heads), then split into k and v.
        packed = self.kv(source).reshape(batch, -1, 2, self.head, head_dim)
        key, value = packed.permute(2, 0, 3, 1, 4)

        weights = (query @ key.transpose(-2, -1)) * self.scale  # (B, heads, N, M)
        weights = weights.softmax(dim=-1)

        # Merge heads back: (B, heads, N, C / heads) -> (B, N, C)
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

"""
深度可分离卷积
输入:[1,3136,128]
输出:[1,3136,32]
"""
class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a token sequence.

    The tokens are folded into an H x W grid, convolved per channel
    (groups == dim, stride 1, padding 1, so the grid size is unchanged), and
    flattened back. Input and output have the same shape (B, H*W, dim).
    Example: (1, 3136, 128) with H = W = 56 -> (1, 3136, 128).
    """

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        batch, _, channels = x.shape
        # (B, N, C) -> (B, C, N) -> (B, C, H, W); view only splits the N axis.
        feature_map = x.transpose(1, 2).view(batch, channels, H, W)
        mixed = self.dwconv(feature_map)
        return mixed.flatten(2).transpose(1, 2)

"""
MLP多层感知机
输入:[1,3136,32]
输出:[1,3136,32]
"""
class MLP(nn.Module):
    """Mix-FFN: Linear (c1 -> c2) -> depthwise 3x3 conv -> GELU -> Linear (c2 -> c1).

    Input and output are (B, N, c1) tokens with N == H * W, because the
    depthwise conv folds the hidden tokens into an H x W grid.
    Example: (1, 3136, 32) -> hidden (1, 3136, 128) -> (1, 3136, 32).
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)    # expand to the hidden width
        self.dwconv = DWConv(c2)        # spatial mixing; keeps shape (B, N, c2)
        self.fc2 = nn.Linear(c2, c1)    # project back to the input width

    def forward(self, x: Tensor, H, W) -> Tensor:
        hidden = self.dwconv(self.fc1(x), H, W)
        return self.fc2(F.gelu(hidden))

"""
功能:通过conv操作将图像划分为不同的patch,再把维度做一个变换
例如:
输入:[1,3,224,224]
过程:[1,3,224,224] --conv--> [1,32,56,56] --flatten,transpose--> [1,3136,32],其中3136=56*56
输出:[1,3136,32]
"""
class PatchEmbed(nn.Module):
    """Overlapping patch embedding: strided conv, flatten to tokens, LayerNorm.

    Example: (1, 3, 224, 224) --conv 7x7 / stride 4--> (1, 32, 56, 56)
             --flatten + transpose--> (1, 3136, 32), where 3136 = 56 * 56.

    Args:
        c1: input channels.
        c2: embedding (output) channels.
        patch_size: conv kernel size. Padding is ``patch_size // 2``, so for
            odd kernels the output grid is ceil(input / stride), and patches
            overlap whenever patch_size > stride.
        stride: conv stride (spatial downsampling factor).

    Returns (from ``forward``):
        ``(tokens, H, W)``: tokens of shape (B, H*W, c2) plus the output grid
        size, which later layers need to fold tokens back into a feature map.
    """

    def __init__(self, c1=3, c2=32, patch_size=7, stride=4):
        super().__init__()
        self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size // 2)
        self.norm = nn.LayerNorm(c2)

    def forward(self, x: Tensor) -> Tuple[Tensor, int, int]:
        x = self.proj(x)                    # (B, c2, H, W)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)    # (B, H*W, c2)
        x = self.norm(x)
        return x, H, W


"""
Block为Mit模型的基本结构
输入:[1,3136,32]
输出:[1,3136,32]
"""
class Block(nn.Module):
    """One MiT transformer layer with two pre-norm residual branches.

    x -> x + drop_path(Attention(LayerNorm(x)))      # efficient self-attention
      -> x + drop_path(MLP(LayerNorm(x)))            # Mix-FFN, hidden width 4*dim
    Input and output: (B, N, dim) tokens with N == H * W.

    ``drop_path`` is ``DropPath(dpr)`` (from semseg.models.layers) when dpr > 0,
    otherwise an identity. NOTE(review): DropPath presumably implements
    stochastic depth, randomly dropping the residual branch for some samples
    during training. Confirm against its implementation.
    """

    def __init__(self, dim, head, sr_ratio=1, dpr=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, head, sr_ratio)
        if dpr > 0.:
            self.drop_path = DropPath(dpr)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * 4))

    def forward(self, x: Tensor, H, W) -> Tensor:
        attended = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attended)
        fed_forward = self.mlp(self.norm2(x), H, W)
        x = x + self.drop_path(fed_forward)
        return x


# MiT variant name -> [embed_dims per stage, number of Blocks per stage].
# B1-B5 share the stage widths [64, 128, 320, 512] and differ only in depth.
mit_settings = dict(
    B0=[[32, 64, 160, 256], [2, 2, 2, 2]],
    B1=[[64, 128, 320, 512], [2, 2, 2, 2]],
    B2=[[64, 128, 320, 512], [3, 4, 6, 3]],
    B3=[[64, 128, 320, 512], [3, 4, 18, 3]],
    B4=[[64, 128, 320, 512], [3, 8, 27, 3]],
    B5=[[64, 128, 320, 512], [3, 6, 40, 3]],
)


"""
    Mit为SegFormer的Encoder部分,主要用于特征提取;
    输入:[1,3,224,224]
    输出:[1,32,56,56],[1,64,28,28],[1,160,14,14],[1,256,7,7]  多尺度信息
"""
class MiT(nn.Module):
    """Mix Transformer (MiT) encoder, the SegFormer backbone.

    Four stages, each = PatchEmbed -> ``depths[i]`` Blocks -> LayerNorm, with
    cumulative strides 4, 8, 16, 32. For a (1, 3, 224, 224) input and
    model_name='B0' the outputs are
    (1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7).

    Args:
        model_name: key into ``mit_settings`` ('B0' ... 'B5').
        drop_path_rate: largest stochastic-depth rate. Per-block rates are
            evenly spaced from 0 to this value over all blocks. Default 0.1.

    Attributes:
        channels: per-stage output channels (a copy of the config's embed_dims).
    """

    def __init__(self, model_name: str = 'B0', drop_path_rate: float = 0.1):
        super().__init__()
        assert model_name in mit_settings.keys(), f"MiT model name should be in {list(mit_settings.keys())}"
        embed_dims, depths = mit_settings[model_name]
        # Copy so mutating self.channels cannot corrupt the shared config table.
        self.channels = list(embed_dims)

        # Stage 1: 7x7 / stride-4 stem; stages 2-4: 3x3 / stride-2 downsampling.
        self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)
        self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)
        self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)
        self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)

        # One drop-path rate per block, linearly increasing across the whole network.
        dpr = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]

        # Per-stage attention heads and key/value spatial-reduction ratios.
        heads = (1, 2, 5, 8)
        sr_ratios = (8, 4, 2, 1)

        # Registers block1, norm1, block2, norm2, ... in the same order as the
        # original hand-written version, so state_dict keys and parameter
        # initialisation order are unchanged.
        cur = 0
        for i, (dim, depth, head, sr_ratio) in enumerate(
            zip(embed_dims, depths, heads, sr_ratios), start=1
        ):
            blocks = nn.ModuleList(
                [Block(dim, head, sr_ratio, dpr[cur + j]) for j in range(depth)]
            )
            setattr(self, f"block{i}", blocks)
            setattr(self, f"norm{i}", nn.LayerNorm(dim))
            cur += depth

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return the four stage outputs as (B, C_i, H_i, W_i) feature maps."""
        B = x.shape[0]
        outs = []
        for i in range(1, 5):
            # Feature map -> tokens (B, H*W, C_i) plus the new grid size.
            x, H, W = getattr(self, f"patch_embed{i}")(x)
            for blk in getattr(self, f"block{i}"):
                x = blk(x, H, W)
            # Tokens -> feature map (B, C_i, H, W); this also feeds the next stage.
            x = getattr(self, f"norm{i}")(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            outs.append(x)
        return tuple(outs)


if __name__ == '__main__':
    # Smoke test: print the four multi-scale feature shapes for a 224x224 input.
    encoder = MiT('B0')
    dummy = torch.zeros(1, 3, 224, 224)
    for feat in encoder(dummy):
        print(feat.shape)
        

2、segformerHead:

import torch
from torch import nn, Tensor
from typing import Tuple
from torch.nn import functional as F


class MLP(nn.Module):
    """Per-position linear projection used by the decode head.

    Flattens every axis after the channel axis into one token axis,
    (B, dim, H, W) -> (B, H*W, dim), then projects each token to
    ``embed_dim`` channels, giving (B, H*W, embed_dim).
    """

    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        tokens = torch.flatten(x, start_dim=2).transpose(1, 2)
        return self.proj(tokens)


class ConvModule(nn.Module):
    """1x1 convolution (no bias) -> BatchNorm2d -> in-place ReLU."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(c2)        # the original implementation uses SyncBN here
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        y = self.conv(x)
        y = self.bn(y)
        return self.activate(y)


class SegFormerHead(nn.Module):
    """All-MLP decode head of SegFormer.

    Each encoder feature map is projected to ``embed_dim`` channels and
    bilinearly resized to the spatial size of ``features[0]``. The maps are
    concatenated, fused by a 1x1 ConvModule, passed through Dropout2d and
    classified per pixel by a 1x1 conv.

    Args:
        dims: channel count of each input feature map, e.g. [32, 64, 160, 256].
        embed_dim: common channel width after projection and fusion.
        num_classes: number of output classes.

    Returns (from ``forward``):
        Logits of shape (B, num_classes, H, W), where (H, W) is the spatial
        size of ``features[0]``.
    """

    def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
        super().__init__()
        for i, dim in enumerate(dims):
            self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))

        # Fuse input width follows the number of feature maps (4 for MiT backbones).
        self.linear_fuse = ConvModule(embed_dim * len(dims), embed_dim)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
        B, _, H, W = features[0].shape
        outs = []
        for i, feature in enumerate(features):
            linear = getattr(self, f"linear_c{i+1}")
            # (B, h*w, embed_dim) -> (B, embed_dim, h, w)
            cf = linear(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
            if i > 0:
                cf = F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(cf)

        # Concatenate deepest-first (c4, c3, c2, c1). linear_fuse's input
        # channels are laid out in this order, so keep it for trained weights.
        seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
        seg = self.linear_pred(self.dropout(seg))
        # The original ended with a bare `return`, which returned None.
        return seg
if __name__ == "__main__":
    # Smoke test with MiT-B0-shaped features for a 224x224 image.
    model = SegFormerHead([32, 64, 160, 256], 256, 19)
    shapes = [[1, 32, 56, 56], [1, 64, 28, 28], [1, 160, 14, 14], [1, 256, 7, 7]]
    features = tuple(torch.zeros(s) for s in shapes)
    outs = model(features)
    print(outs.shape)

 3、SegFormer 完整模型（源码：https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/segformer.py）

import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead


class SegFormer(BaseModel):
    """SegFormer: MiT backbone + all-MLP SegFormerHead.

    ``forward`` returns per-pixel class logits, bilinearly upsampled back to
    the input's spatial size: (B, C, H, W) -> (B, num_classes, H, W).
    NOTE(review): ``self.backbone`` and ``self._init_weights`` are assumed to
    be provided by BaseModel. Confirm against semseg.models.base.
    """

    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
        super().__init__(backbone, num_classes)
        # Lightweight variants (B0/B1) get a 256-wide decoder, larger ones 768.
        is_small = any(tag in backbone for tag in ('B0', 'B1'))
        head_dim = 256 if is_small else 768
        self.decode_head = SegFormerHead(self.backbone.channels, head_dim, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: Tensor) -> Tensor:
        features = self.backbone(x)
        # Head output is at the resolution of the first (highest-res) feature map.
        logits = self.decode_head(features)
        return F.interpolate(logits, size=x.shape[2:], mode='bilinear', align_corners=False)


if __name__ == '__main__':
    net = SegFormer('MiT-B0')
    # To use pretrained weights:
    # net.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu'))
    image = torch.zeros(1, 3, 512, 512)
    logits = net(image)
    print(logits.shape)

你可能感兴趣的:(算法研读,图像分割,注意力机制,深度学习,计算机视觉)