All of the code in this article was run in PyCharm.
SwinTransformer Algorithm Principles
SwinTransformer Source Code Walkthrough 1 (Project Setup / SwinTransformer class)
SwinTransformer Source Code Walkthrough 2 (PatchEmbed class / BasicLayer class)
SwinTransformer Source Code Walkthrough 3 (SwinTransformerBlock class)
SwinTransformer Source Code Walkthrough 4 (WindowAttention class)
SwinTransformer Source Code Walkthrough 5 (Mlp class / PatchMerging class)
class SwinTransformerBlock(nn.Module):
    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
SwinTransformerBlock is a basic building block of the Swin Transformer model. It combines window-based self-attention with a multilayer perceptron (MLP), and uses window partitioning plus an optional window shift to realize local attention.
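As a rough sketch of how these blocks are used (the hyperparameters below are illustrative only, and it is assumed that the SwinTransformerBlock class defined below and its dependencies are importable), consecutive blocks in a stage alternate between regular windows (shift_size=0, W-MSA) and shifted windows (shift_size=window_size // 2, SW-MSA):

import torch.nn as nn

# Hypothetical Swin-T-like stage-1 settings: 96 channels, 56x56 tokens, 7x7 windows
w_msa_block = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                                   window_size=7, shift_size=0)        # regular windows (W-MSA)
sw_msa_block = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                                    window_size=7, shift_size=7 // 2)  # shifted windows (SW-MSA)
pair = nn.Sequential(w_msa_block, sw_msa_block)

The constructor below sets up all of the pieces this pairing relies on: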
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the window is larger than the input resolution, don't partition or shift
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate the attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
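To make the mask construction above concrete, here is a small standalone sketch (not part of the repo): the sizes are toy values chosen for readability, and the window partitioning is written inline instead of calling window_partition.

import torch

H = W = 4
window_size, shift_size = 2, 1

# Label the 3x3 grid of regions produced by the cyclic shift with the ids 0..8
img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = h_slices
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# Same reshaping that window_partition performs, inlined for this sketch
windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)  # torch.Size([4, 4, 4]) -> (num_windows, window_size*window_size, window_size*window_size)
print(attn_mask[0])     # top-left window covers a single region, so its mask is all zeros
print(attn_mask[-1])    # bottom-right window mixes four regions, so its off-diagonal entries are -100

Adding -100.0 to an attention logit drives its softmax weight to roughly zero, so tokens that were only brought together by the cyclic shift do not attend to each other. The constructor arguments and local variables are summarized below.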
dim: number of channels of the input features.
input_resolution: resolution (height and width) of the input features.
num_heads: number of self-attention heads.
window_size: window size, which determines the local range of the attention.
shift_size: size of the window shift, used to implement shifted-window multi-head self-attention (SW-MSA).
mlp_ratio: ratio of the MLP hidden dimension to the number of input channels.
qkv_bias: whether to add a bias to Q, K, V.
qk_scale: scaling factor applied to the query-key dot product.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path: stochastic depth (drop path) rate.
act_layer, norm_layer: activation layer and normalization layer, defaulting to GELU and LayerNorm respectively.
WindowAttention: the window attention module.
Mlp: a module made up of fully connected layers, an activation function, and Dropout.
img_mask: image mask used to build the shifted-window self-attention mask.
h_slices and w_slices: slices along the height and width used to divide the image mask into regions.
cnt: counter that labels the different regions.
mask_windows: the image mask partitioned into windows by window_partition, with each window's mask reshaped into a 1-D vector.
attn_mask: attention mask used in the self-attention computation to mask out positions that come from different regions of the shifted image.
register_buffer: registers the attention mask as a buffer of the module.
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # residual connections around attention and MLP (FFN)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
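A minimal smoke test of the forward pass (hypothetical shapes; it assumes SwinTransformerBlock and its dependencies, including WindowAttention, Mlp, DropPath, to_2tuple, window_partition and window_reverse, are importable from the repo):

import torch

block = SwinTransformerBlock(dim=96, input_resolution=(56, 56),
                             num_heads=3, window_size=7, shift_size=3)
x = torch.randn(2, 56 * 56, 96)  # (B, L, C) with L = H * W
out = block(x)
print(out.shape)                 # torch.Size([2, 3136, 96]) -- the block preserves (B, L, C)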
def window_partition(x, window_size):
    # x: (B, H, W, C) -> windows: (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
window_partition splits the feature map into non-overlapping windows, and Window Attention then performs its attention computation inside each of these 7×7 windows.
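As a quick shape check (hypothetical tensor sizes, assuming the window_partition function above is in scope): a 56×56 feature map with window_size=7 splits into (56/7) × (56/7) = 64 windows per image, each holding 7×7 = 49 tokens.

import torch

x = torch.randn(1, 56, 56, 96)    # (B, H, W, C)
windows = window_partition(x, 7)  # -> (num_windows * B, window_size, window_size, C)
print(windows.shape)              # torch.Size([64, 7, 7, 96])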