Liu, Ze, et al. “Swin transformer: Hierarchical vision transformer using shifted windows.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
This paper cemented the Transformer's place in computer vision. Unlike ViT (Vision Transformer), which fixes the patch partition at the very beginning so that the receptive field never changes, Swin Transformer proposes a hierarchical structure that borrows the downsampling design of traditional CNNs and uses a different receptive-field scale at each stage, ultimately achieving better performance than ViT.
The official code performs the initial patch partition with a convolution whose kernel_size and stride are both equal to patch_size.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_c=1, embed_dim=96, norm_layer=None):
super(PatchEmbed, self).__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_c
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_channels = in_c,
out_channels = embed_dim,
kernel_size=patch_size,
stride= patch_size
        )  # partition into patches with a convolution whose kernel_size and stride both equal patch_size
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# x [batch_size, c, h, w]
_, _, H, W = x.shape
        # if H or W is not an integer multiple of patch_size, pad the input
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        # F.pad(input, pad) pads the tensor starting from the last dimension:
        # e.g. (1, 2) pads the last dim (W) with 1 column on the left and 2 on the right;
        # (1, 2, 3, 4) additionally pads the second-to-last dim (H) with 3 rows on top and 4 at the bottom
if pad_input:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
0, self.patch_size[0] - H % self.patch_size[0], 0, 0))
x = self.proj(x) # [batch_size, embed_dim, h//patch_size, w//patch_size]
        _, _, H, W = x.shape  # H, W are the height and width of the projected feature map
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        # the final output is HW patch tokens, each with embed_dim channels
return x, H, W
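As a quick sanity check, here is a minimal sketch of a forward pass (the batch size and the 224x224 input resolution are just illustrative assumptions):

patch_embed = PatchEmbed(patch_size=4, in_c=1, embed_dim=96, norm_layer=nn.LayerNorm)
img = torch.randn(2, 1, 224, 224)          # [B, C, H, W], single-channel as in the in_c=1 default
tokens, H, W = patch_embed(img)
print(tokens.shape, H, W)                  # torch.Size([2, 3136, 96]) 56 56, i.e. 56*56 patch tokens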
http://www.manongjc.com/detail/24-jjgknxdkdzormze.html
To reduce overfitting, the official code also uses DropPath (stochastic depth); see the link above for details. The code is as follows:
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    DropPath randomly "drops" entire branches (e.g. residual paths) of the network on a per-sample basis.
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path_f(x, self.drop_prob, self.training)
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
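A tiny demonstration (toy tensor, values assumed purely for illustration): during training each sample's path survives with probability keep_prob and is rescaled by 1/keep_prob so the expectation is preserved, while at inference DropPath is a no-op.

torch.manual_seed(0)
dp = DropPath(drop_prob=0.5)
dp.train()
x = torch.ones(8, 4)
print(dp(x)[:, 0])              # roughly half the rows are 0, the surviving rows are scaled to 2.0
dp.eval()
print(torch.equal(dp(x), x))    # True: identity at inference time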
The biggest difference between Swin Transformer and ViT is the notion of windows: self-attention is computed only among the pixels/patches inside each window rather than across the whole image, which makes the computation far more efficient.
def window_partition(x,window_size):
"""
    Partition the feature map into non-overlapping windows.
    window_partition is analogous to the patch partition in ViT:
    several patches are grouped into one window.
:param x: (B,H,W,C)
:param window_size: (M)
:return: windows: (num_windows*B, window_size, window_size, C)
"""
B,H,W,C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)  # so that MSA is computed among the patches within each window
return windows
def window_reverse(windows, window_size: int, H: int, W: int):
"""
    Merge the individual windows back into one feature map.
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size(M)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
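A round-trip check (shapes assumed) confirming that window_partition and window_reverse are inverses when H and W are multiples of the window size:

x = torch.randn(2, 56, 56, 96)                    # [B, H, W, C]
windows = window_partition(x, window_size=7)      # 2 images * 8*8 windows each
print(windows.shape)                              # torch.Size([128, 7, 7, 96])
x_back = window_reverse(windows, 7, 56, 56)
print(torch.equal(x_back, x))                     # True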
https://blog.csdn.net/qq_37541097/article/details/121119988?spm=1001.2014.3001.5502
The blog post above summarizes patch merging well. In short, it does something very similar to a CNN downsampling layer: the number of channels is doubled while the height and width are halved, so it can take the place of strided convolution/pooling. The patch partition in the paper is essentially doing the same kind of grouping, and it should not be confused with window_size.
class PatchMerging(nn.Module):
"""
用来在每个Stage开始前进行DownSample,以缩小分辨率,并调整通道数量,以达到分层和高效的作用。
- 类似于CNN内,通过调整Stride来降低分辨率的作用。
Step1: 行列间隔2选取元素
Step2: 拼接为一整个Tensor(通道数变为4倍)
Step3: 通过FC Layer调整通道数
"""
def __init__(self,dim,norm_layer=nn.LayerNorm):
        # dim: number of input channels
        super(PatchMerging, self).__init__()
        self.dim = dim
        # reduce 4*dim channels to 2*dim channels
self.reduction = nn.Linear(4*dim,2*dim,bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
x: B, H*W, C
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
        # padding
        # if H or W is odd, the feature map has to be padded
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
# to pad the last 3 dimensions, starting from the last dimension and moving forward.
# (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # note the tensor layout here is [B, H, W, C], so the pad order differs slightly from the usual [B, C, H, W] case
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
# 间隔2取元素
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
        # reduce the channel count
x = self.reduction(x) # [B, H/2*W/2, 2*C]
return x
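A shape check (example sizes assumed) of the "halve the resolution, double the channels" behaviour:

merge = PatchMerging(dim=96)
x = torch.randn(2, 56 * 56, 96)      # tokens of a 56x56 feature map
out = merge(x, H=56, W=56)
print(out.shape)                     # torch.Size([2, 784, 192]): 28*28 tokens with 2*96 channels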
The paper uses two kinds of window attention, W-MSA and SW-MSA. W-MSA is straightforward: attention is simply computed inside each window. The drawback is that the windows are isolated from each other, with no exchange of information. To solve this, the authors propose SW-MSA, which shifts the window grid so that neighbouring windows can communicate; W-MSA and SW-MSA always appear in pairs. In the paper's figure, the left side is W-MSA and the right side is SW-MSA. After the shift the image is split into 9 regions, and running MSA on each of them separately would be much more expensive, so the authors propose an efficient scheme: a cyclic shift of the feature map brings the number of windows back to that of W-MSA, and an attention mask blocks interactions between patches that are not actually adjacent.
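The cyclic shift itself is just torch.roll; a toy 4x4 "feature map" (values assumed purely for illustration) makes the rearrangement visible:

x = torch.arange(16).view(1, 4, 4, 1)                    # [B, H, W, C]
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))    # shift up and left by shift_size = 2
print(shifted[0, :, :, 0])
# tensor([[10, 11,  8,  9],
#         [14, 15, 12, 13],
#         [ 2,  3,  0,  1],
#         [ 6,  7,  4,  5]])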
The code also adds a Relative Position Bias, which gives a modest accuracy improvement but is only described briefly in the paper. For details see:
https://blog.csdn.net/qq_37541097/article/details/121119988?spm=1001.2014.3001.5502
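As a concrete illustration of the index construction used in WindowsAttention below, here is the relative position index worked out for a 2x2 window (the window_size of (2, 2) is assumed just for this example; the bias table then has (2*2-1)*(2*2-1) = 9 rows):

ws = (2, 2)
coords = torch.stack(torch.meshgrid([torch.arange(ws[0]), torch.arange(ws[1])], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                       # [2, 4]
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # [2, 4, 4]
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += ws[0] - 1          # shift row offsets to start from 0
rel[:, :, 1] += ws[1] - 1          # shift column offsets to start from 0
rel[:, :, 0] *= 2 * ws[1] - 1      # flatten (row, col) offsets into a single index
print(rel.sum(-1))
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])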
class WindowsAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
"""
在一个window中做MSA
:param dim: 输入通道数
:param window_size: 窗口尺寸
:param num_heads:
:param qkv_bias:
:param attn_drop:
:param proj_drop:
"""
super(WindowsAttention, self).__init__()
self.dim = dim
self.window_size = window_size # [Mh, Mw]
self.num_heads = num_heads
        head_dim = dim // num_heads  # dimension of each attention head
        self.scale = head_dim ** -0.5  # scaling factor 1/sqrt(d)
        # parameter table that stores the relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH]
        # get the pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0]) # [Mh]
coords_w = torch.arange(self.window_size[1]) # [Mw]
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw]
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw]
        # register_buffer stores a tensor that is not a learnable parameter:
        # it is saved and loaded with the state_dict but receives no gradient
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self,x,mask=None):
"""
Args:
x: input features with shape of (num_windows*B, Mh*Mw, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        the batch dimension of x is num_windows * batch_size;
        attention is computed inside each window
"""
# [batch_size*num_windows, Mh*Mw, total_embed_dim]
B_, N, C = x.shape
# qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
# reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
# permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
q,k,v = qkv.unbind(0)
# QK^T/sqrt(d)
# transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
# QK^T/sqrt(d) + B
# B:
# relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw]
# [Bs*nW, nH, Mh*Mw, Mh*Mw]
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
            # SW-MSA needs an attention mask
            # mask: [nW, Mh*Mw, Mh*Mw]
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
# transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
# reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
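A minimal forward pass through the window attention module (shapes assumed; mask=None corresponds to the W-MSA case):

attn = WindowsAttention(dim=96, window_size=(7, 7), num_heads=3)
x = torch.randn(128, 49, 96)     # [num_windows * B, Mh*Mw, C]
print(attn(x).shape)             # torch.Size([128, 49, 96])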
Each stage stacks W-MSA and SW-MSA blocks in pairs, so the number of blocks per stage must be even.
class SwinTransformerBlock(nn.Module):
"""
Swin Transformer Block包括:
Feature Map Input -> LayerNorm -> SW-MSA/W-MSA -> LayerNorm-> MLP -------->
|--------------------------------------||----------------------|
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
"""
        Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
super(SwinTransformerBlock, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
        # shift_size must be smaller than window_size
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
        # LN1
        self.norm1 = norm_layer(dim)
        # window multi-head self-attention (W-MSA / SW-MSA)
self.attn = WindowsAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# LN2
self.norm2 = norm_layer(dim)
# MLP Layer
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, attn_mask):
        # height and width of the feature map, i.e. the resolution coming out of the preceding patch embedding / merging
H, W = self.H, self.W
# Batch, length, channel
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
# Skip Connect
shortcut = x
x = self.norm1(x)
# reshape feature map
        x = x.view(B, H, W, C)  # restore the 2D feature-map layout
        # pad the feature map so that H and W are integer multiples of window_size
pad_l = 0
pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x,(0,0,pad_l,pad_r,pad_t,pad_b))
        # Hp, Wp: height and width of the padded feature map
        _, Hp, Wp, _ = x.shape
        # W-MSA or SW-MSA?
# cyclic shift
if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # cyclic shift up and to the left
else:
shifted_x = x
attn_mask = None
        # window partition
x_windows = window_partition(shifted_x,self.window_size) #[nW*B, Mh, Mw, C]
x_windows = x_windows.view(-1, self.window_size*self.window_size,C) # [nW*B, Mh*Mw, C]
# W-MSA / SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
        # merge the windows back into one feature map
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C]
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C]
        # for SW-MSA, the cyclic shift has to be reversed
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
        # remove the padding added earlier
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
x = x.view(B,H*W,C)
# FFN
        # two skip connections
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
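The block above relies on an Mlp module that is not included in this excerpt. A minimal sketch in the spirit of the common timm-style implementation (the original file may differ in details) is:

class Mlp(nn.Module):
    """Two-layer feed-forward network used inside each Transformer block."""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)
    def forward(self, x):
        x = self.drop1(self.act(self.fc1(x)))
        x = self.drop2(self.fc2(x))
        return x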
"""一个Stage内的基本SwinTransformer模块"""
class BasicLayer(nn.Module):
"""
One Stage SwinTransformer Layer包括:
"""
def __init__(self, dim, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
"""
Args:
dim (int): Number of input channels.
            depth (int): Number of blocks in this stage.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
super(BasicLayer, self).__init__()
self.dim = dim
self.depth = depth
self.window_size = window_size
        self.use_checkpoint = use_checkpoint  # whether to use gradient checkpointing to save memory
        self.shift_size = window_size // 2
        # build the Swin Transformer blocks of this stage
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,  # even i: W-MSA, odd i: SW-MSA, as in the paper, so that windows can communicate
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
        # patch merging layer, similar to pooled downsampling in a CNN
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
    def create_mask(self, x, H, W):
        """
        Build the attention mask for SW-MSA.
        After the cyclic shift, the elements in the top-left window (which was the central window before the shift)
        are still genuine neighbours, so they may all attend to each other. The remaining windows, however, contain
        elements that were moved in from far-away positions, so they should not attend to each other: there is no
        real spatial relation between them. (A 14x14-patch feature map is a convenient example to picture this.)
        x: feature map
        H: feature map height
        W: feature map width
        """
        # compute the attention mask for SW-MSA
        # make sure Hp and Wp are integer multiples of window_size
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # same layout as the feature map, so window_partition can be reused below
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1]
        # slice the padded map into regions so that each region can be labelled for the mask
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
        # label each region with an index
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
        # partition the labelled map into windows (a shifted window may mix several regions)
mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1]
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]
        # build the mask: 0 where two tokens come from the same region, non-zero otherwise
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
# [nW, Mh*Mw, Mh*Mw]
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self,x,H,W):
        # attn_mask: [nW, Mh*Mw, Mh*Mw], where nW is the number of windows
attn_mask = self.create_mask(x,H,W)
for blk in self.blocks:
            blk.H, blk.W = H, W  # pass the current feature-map resolution to each block
if not torch.jit.is_scripting() and self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2  # after downsampling, H and W are halved
return x, H, W
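A quick look at the SW-MSA mask produced by create_mask (example sizes assumed; this also exercises the BasicLayer constructor, which needs the Mlp sketch above):

layer = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7)
x = torch.randn(1, 56 * 56, 96)
mask = layer.create_mask(x, 56, 56)
print(mask.shape)                # torch.Size([64, 49, 49]); the -100 entries block attention across regions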
"""Swin Transformer"""
class SwinTransformer(nn.Module):
"""Swin Transformer结构
这里有个不同之处,就是每个Stage Layer中,
"""
def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=nn.LayerNorm, patch_norm=True,
use_checkpoint=False, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
        # number of channels of the output feature map (C)
# H/4 x W/4 x 48 -> H/4 x W/4 x C(Stage1) -> H/8 x W/8 x 2C(Stage2) -> H/16 x W/16 x 4C(stage3) ...
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
        # split the image into non-overlapping patches
        # input:  (Bs, 3, 224, 224)
        # output: (Bs, 56*56, embed_dim) for patch_size=4
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
# Drop Path
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
        # build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
            # note: the stages built here differ slightly from the figure in the paper:
            # a stage does not include its own patch_merging layer but the one belonging to the next stage
layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layers)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self,x):
        # x: [B, C, H, W] -> [B, L, C]
x,H,W = self.patch_embed(x)
x = self.pos_drop(x)
        # multi-scale hierarchical stages
for layer in self.layers:
x,H,W = layer(x,H,W)
x = self.norm(x) # [B, L, C]
x = self.avgpool(x.transpose(1, 2)) # [B, C, 1]
x = torch.flatten(x, 1)
        x = self.head(x)  # classification head
return x
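Finally, an end-to-end sanity check using the Swin-T defaults from the signature above (the 224x224 input resolution is an assumption that matches window_size=7 and the four stages):

model = SwinTransformer(patch_size=4, in_chans=3, num_classes=1000,
                        embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                        window_size=7)
img = torch.randn(1, 3, 224, 224)
print(model(img).shape)          # torch.Size([1, 1000])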