论文详细讲解如下:
【论文笔记】Segformer论文阅读笔记_嘟嘟太菜了的博客-CSDN博客
论文笔记——Segformer: 一种基于Transformer的语义分割方法 - 知乎
上图是SegFormer的整体结构
----------------------------------------------------------------------------------------------------------------------
segformer = backbone(mit) + segformerHead
1、backbone(mit):
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from semseg.models.layers import DropPath
"""
Attention模块计算注意力
例如:
输入:x:[1,3136,32]
输出:x:[1,3136,32]
"""
class Attention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    Queries are always computed from the full token sequence. When
    ``sr_ratio > 1`` the sequence is first folded back into a ``(B, C, H, W)``
    map, shrunk by a convolution whose kernel size and stride both equal
    ``sr_ratio``, flattened again and layer-normalized; keys and values are
    then computed from that shorter sequence.

    Example (dim=32, head=1, sr_ratio=8, H=W=56):
        input:  [1, 3136, 32]
        output: [1, 3136, 32]

    Args:
        dim: Channel width of the tokens.
        head: Number of attention heads; channels are split evenly across them.
        sr_ratio: Spatial reduction factor applied before computing k/v.
    """

    def __init__(self, dim, head, sr_ratio):
        super().__init__()
        self.head = head
        self.sr_ratio = sr_ratio
        # Factor multiplied into q @ k^T: 1 / sqrt(channels per head).
        self.scale = (dim // head) ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)  # keys and values share one projection
        self.proj = nn.Linear(dim, dim)

        # The reduction layers only exist when a reduction is requested.
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Attend over ``x`` of shape ``(B, N, C)`` where ``N == H * W``."""
        B, N, C = x.shape
        head_dim = C // self.head

        # [B, N, C] -> [B, head, N, head_dim]
        queries = self.q(x).reshape(B, N, self.head, head_dim).transpose(1, 2)

        kv_source = x
        if self.sr_ratio > 1:
            # [B, N, C] -> [B, C, H, W] -> [B, C, H/sr, W/sr] -> [B, N', C]
            fmap = x.transpose(1, 2).reshape(B, C, H, W)
            reduced = self.sr(fmap)
            kv_source = self.norm(reduced.reshape(B, C, -1).transpose(1, 2))

        # [B, N', 2C] -> [2, B, head, N', head_dim], then split into k and v.
        packed = self.kv(kv_source).reshape(B, -1, 2, self.head, head_dim)
        keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # [B, head, N, N'] attention weights, normalized over the key axis.
        weights = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = weights.softmax(dim=-1)

        # Merge the heads back: [B, head, N, head_dim] -> [B, N, C]
        merged = torch.matmul(weights, values).transpose(1, 2).reshape(B, N, C)
        return self.proj(merged)
"""
深度可分离卷积
输入:[1,3136,128]
输出:[1,3136,128]
"""
class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a token sequence.

    The ``(B, N, C)`` sequence is viewed as a ``(B, C, H, W)`` map, convolved
    with one 3x3 filter per channel (``groups=dim``, stride 1, padding 1, so
    spatial size and channel count are unchanged) and flattened back.

    Example (dim=128, H=W=56):
        input:  [1, 3136, 128]
        output: [1, 3136, 128]
    """

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        batch, _, channels = x.shape
        # [B, N, C] -> [B, C, H, W]
        fmap = x.permute(0, 2, 1).view(batch, channels, H, W)
        fmap = self.dwconv(fmap)
        # [B, C, H, W] -> [B, N, C]
        return fmap.flatten(start_dim=2).permute(0, 2, 1)
"""
MLP多层感知机
输入:[1,3136,32]
输出:[1,3136,32]
"""
class MLP(nn.Module):
    """Feed-forward block: Linear -> depthwise conv -> GELU -> Linear.

    ``fc1`` widens the tokens from ``c1`` to ``c2`` channels, ``DWConv`` mixes
    them spatially at that width, and ``fc2`` projects back to ``c1``.

    Example (c1=32, c2=128, H=W=56):
        input:  [1, 3136, 32]
        output: [1, 3136, 32]
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x: Tensor, H, W) -> Tensor:
        # [B, N, c1] -> [B, N, c2]; the depthwise conv keeps the c2 width.
        hidden = self.dwconv(self.fc1(x), H, W)
        # Activation, then back to [B, N, c1].
        return self.fc2(F.gelu(hidden))
"""
功能:通过conv操作将图像划分为不同的patch,再把维度做一个变换
例如:
输入:[1,3,224,224]
过程:[1,3,224,224] --conv--> [1,32,56,56] --flatten,transpose--> [1,3136,32],其中3136=56*56
输出:[1,3136,32]
"""
class PatchEmbed(nn.Module):
    """Split an image or feature map into patch tokens with a strided conv.

    A convolution with kernel ``patch_size``, stride ``stride`` and padding
    ``patch_size // 2`` produces a ``(B, c2, H, W)`` map, which is flattened
    to a ``(B, H*W, c2)`` token sequence and layer-normalized.

    Example (defaults):
        [1, 3, 224, 224] --conv--> [1, 32, 56, 56]
                         --flatten, transpose--> [1, 3136, 32]   (3136 = 56 * 56)

    Args:
        c1: Input channels.
        c2: Embedding (output) channels.
        patch_size: Convolution kernel size.
        stride: Convolution stride.
    """

    def __init__(self, c1=3, c2=32, patch_size=7, stride=4):
        super().__init__()
        self.proj = nn.Conv2d(
            c1, c2, kernel_size=patch_size, stride=stride, padding=patch_size // 2
        )
        self.norm = nn.LayerNorm(c2)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``(tokens, H, W)`` where ``H`` and ``W`` are the conv output size."""
        fmap = self.proj(x)
        H, W = fmap.shape[-2:]
        # [B, c2, H, W] -> [B, H*W, c2]
        tokens = fmap.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W
"""
Block为Mit模型的基本结构
输入:[1,3136,32]
输出:[1,3136,32]
"""
class Block(nn.Module):
    """Basic MiT building block: pre-norm attention followed by a pre-norm MLP.

    Both sub-layers are wrapped in a residual connection, and each residual
    branch passes through ``drop_path`` before being added back.

    Example (dim=32, H=W=56):
        input:  [1, 3136, 32]
        output: [1, 3136, 32]

    Args:
        dim: Token channel width.
        head: Number of attention heads.
        sr_ratio: Spatial reduction ratio forwarded to ``Attention``.
        dpr: Drop-path rate; when it is not positive the branch uses
            ``nn.Identity`` instead of ``DropPath``.
    """

    def __init__(self, dim, head, sr_ratio=1, dpr=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, head, sr_ratio)
        # NOTE(review): DropPath is imported from semseg.models.layers; it
        # presumably drops the whole residual branch for a random subset of
        # samples in the batch -- confirm against that module.
        if dpr > 0.:
            self.drop_path = DropPath(dpr)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * 4))  # hidden width is 4x the token width

    def forward(self, x: Tensor, H, W) -> Tensor:
        # Efficient self-attention sub-layer.
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        # Mix-FFN sub-layer.
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
# Per-variant MiT configuration: name -> [embed_dims, depths], one entry per stage.
mit_settings = {
    variant: [list(embed_dims), list(depths)]
    for variant, embed_dims, depths in (
        ('B0', (32, 64, 160, 256), (2, 2, 2, 2)),
        ('B1', (64, 128, 320, 512), (2, 2, 2, 2)),
        ('B2', (64, 128, 320, 512), (3, 4, 6, 3)),
        ('B3', (64, 128, 320, 512), (3, 4, 18, 3)),
        ('B4', (64, 128, 320, 512), (3, 8, 27, 3)),
        ('B5', (64, 128, 320, 512), (3, 6, 40, 3)),
    )
}
"""
Mit为SegFormer的Encoder部分,主要用于特征提取;
输入:[1,3,224,224]
输出:[1,32,56,56],[1,64,28,28],[1,160,14,14],[1,256,7,7] 多尺度信息
"""
class MiT(nn.Module):
    """Mix Transformer (MiT): the encoder used by SegFormer for feature extraction.

    Four stages run in sequence. Each stage embeds its input with a strided
    convolution (``PatchEmbed``), applies a stack of ``Block`` layers,
    layer-normalizes the tokens and reshapes them back to a ``(B, C, H, W)``
    map. That map is both returned and fed to the next stage, so the result
    is a pyramid of four feature maps.

    Example (model_name='B0'):
        input:   [1, 3, 224, 224]
        outputs: [1, 32, 56, 56], [1, 64, 28, 28], [1, 160, 14, 14], [1, 256, 7, 7]

    Args:
        model_name: Key into ``mit_settings`` ('B0' ... 'B5').
        drop_path_rate: Largest per-block drop-path rate. Rates are spaced
            evenly from 0 to this value over all blocks of all stages.

    Raises:
        AssertionError: If ``model_name`` is not a key of ``mit_settings``.
    """

    def __init__(self, model_name: str = 'B0', drop_path_rate: float = 0.1):
        super().__init__()
        assert model_name in mit_settings.keys(), f"MiT model name should be in {list(mit_settings.keys())}"
        embed_dims, depths = mit_settings[model_name]
        # Exposed so a decode head can size its input projections.
        self.channels = embed_dims

        # Patch embeddings: stage 1 uses a 7x7/stride-4 conv, later stages 3x3/stride-2.
        self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)
        self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)
        self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)
        self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)

        # One drop-path rate per block, evenly spaced over sum(depths) blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Each stage: `depths[i]` blocks with (dim, heads, sr_ratio) fixed per
        # stage; `cur` is the index of the stage's first block in `dpr`.
        cur = 0
        self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])
        self.norm1 = nn.LayerNorm(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])
        self.norm2 = nn.LayerNorm(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])
        self.norm3 = nn.LayerNorm(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])
        self.norm4 = nn.LayerNorm(embed_dims[3])

    @staticmethod
    def _forward_stage(x: Tensor, patch_embed, blocks, norm) -> Tensor:
        """Run one stage and return its output as a ``(B, C, H, W)`` map."""
        tokens, H, W = patch_embed(x)
        for blk in blocks:
            tokens = blk(tokens, H, W)
        # [B, H*W, C] -> [B, H, W, C] -> [B, C, H, W]
        return norm(tokens).reshape(tokens.shape[0], H, W, -1).permute(0, 3, 1, 2)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return the four stage outputs, from highest to lowest resolution."""
        x1 = self._forward_stage(x, self.patch_embed1, self.block1, self.norm1)
        x2 = self._forward_stage(x1, self.patch_embed2, self.block2, self.norm2)
        x3 = self._forward_stage(x2, self.patch_embed3, self.block3, self.norm3)
        x4 = self._forward_stage(x3, self.patch_embed4, self.block4, self.norm4)
        return x1, x2, x3, x4
if __name__ == '__main__':
    # Smoke test: print the shape of each MiT-B0 stage output for a 224x224 input.
    model = MiT('B0')
    x = torch.zeros(1, 3, 224, 224)
    for feature_map in model(x):
        print(feature_map.shape)
2、segformerHead:
import torch
from torch import nn, Tensor
from typing import Tuple
from torch.nn import functional as F
class MLP(nn.Module):
    """Per-pixel linear projection of a feature map.

    A ``(B, dim, H, W)`` map is flattened to ``(B, H*W, dim)`` tokens and each
    token is projected to ``embed_dim`` channels, giving ``(B, H*W, embed_dim)``.
    """

    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        # [B, dim, H, W] -> [B, dim, H*W] -> [B, H*W, dim]
        tokens = torch.flatten(x, start_dim=2).permute(0, 2, 1)
        return self.proj(tokens)
class ConvModule(nn.Module):
    """1x1 convolution (no bias) followed by batch norm and in-place ReLU.

    Args:
        c1: Input channels.
        c2: Output channels.
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, kernel_size=1, bias=False)
        # Per the source note, the original implementation uses SyncBN here.
        self.bn = nn.BatchNorm2d(c2)
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        y = self.conv(x)
        y = self.bn(y)
        return self.activate(y)
class SegFormerHead(nn.Module):
    """All-MLP decode head: project, upsample, fuse and classify encoder features.

    Every encoder feature map is projected to ``embed_dim`` channels by its own
    ``MLP``. All but the first are bilinearly resized to the spatial size of
    the first (highest-resolution) map. The projections are concatenated
    (deepest stage first), fused by a 1x1 ``ConvModule`` and classified
    per pixel by a 1x1 convolution.

    Args:
        dims: Channel count of each encoder feature map, in stage order.
        embed_dim: Common channel width used inside the head.
        num_classes: Number of output classes.
    """

    def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
        super().__init__()
        # Registered as linear_c1 ... linear_cN; the names are part of the
        # module's parameter naming, so they are kept as-is.
        for i, dim in enumerate(dims):
            self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))

        # The fused tensor holds one embed_dim-wide projection per input stage.
        self.linear_fuse = ConvModule(embed_dim * len(dims), embed_dim)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:
        """Return class logits at the resolution of ``features[0]``.

        Args:
            features: Encoder feature maps, each ``(B, C_i, H_i, W_i)``, ordered
                from highest to lowest resolution.

        Returns:
            Tensor of shape ``(B, num_classes, H_0, W_0)``.
        """
        B, _, H, W = features[0].shape

        outs = []
        for i, feature in enumerate(features, start=1):
            # Look the projection up by name instead of eval()-ing source text.
            project = getattr(self, f"linear_c{i}")
            # [B, H_i*W_i, E] -> [B, E, H_i, W_i]
            cf = project(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
            if i > 1:
                # Bring the coarser stages up to the first stage's resolution.
                cf = F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(cf)

        # Concatenate deepest stage first, then fuse and classify.
        seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
        seg = self.linear_pred(self.dropout(seg))
        return seg
if __name__ == "__main__":
    # Smoke test: feed zero tensors shaped like the four MiT-B0 stage outputs.
    channels = [32, 64, 160, 256]
    sizes = [56, 28, 14, 7]
    model = SegFormerHead(channels, 256, 19)
    features = tuple(torch.zeros([1, c, s, s]) for c, s in zip(channels, sizes))
    outs = model(features)
    print(outs.shape)
3、segformer:https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/segformer.py
import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead
class SegFormer(BaseModel):
    """SegFormer segmentation model: a backbone plus a ``SegFormerHead``.

    The head width is 256 when the backbone name contains 'B0' or 'B1' and
    768 otherwise.

    Args:
        backbone: Backbone name passed to ``BaseModel``, e.g. 'MiT-B0'.
        num_classes: Number of output classes.
    """

    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
        # NOTE(review): BaseModel is defined elsewhere; it is presumably what
        # provides self.backbone and self._init_weights -- confirm there.
        super().__init__(backbone, num_classes)
        uses_small_backbone = any(tag in backbone for tag in ('B0', 'B1'))
        embed_dim = 256 if uses_small_backbone else 768
        self.decode_head = SegFormerHead(self.backbone.channels, embed_dim, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Return logits resized to the spatial size of ``x``."""
        features = self.backbone(x)
        # Per the source note, the head output is 4x smaller than the image.
        logits = self.decode_head(features)
        # Bilinearly resize back to the input's height and width.
        return F.interpolate(logits, size=x.shape[2:], mode='bilinear', align_corners=False)
if __name__ == '__main__':
    # Smoke test: run a 512x512 zero image through SegFormer with an MiT-B0 backbone.
    model = SegFormer('MiT-B0')
    # To try pretrained weights, load a checkpoint first, e.g.:
    # model.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu'))
    dummy = torch.zeros(1, 3, 512, 512)
    print(model(dummy).shape)