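Feel free to leave any questions in the comments below.
All code in this article was run in PyCharm.
The accompanying source code for this article has been uploaded.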
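SwinTransformer Algorithm Principles
SwinTransformer Source Code Walkthrough 1 (Project Configuration / SwinTransformer Class)
SwinTransformer Source Code Walkthrough 2 (PatchEmbed Class / BasicLayer Class)
SwinTransformer Source Code Walkthrough 3 (SwinTransformerBlock Class)
SwinTransformer Source Code Walkthrough 4 (WindowAttention Class)
SwinTransformer Source Code Walkthrough 5 (Mlp Class / PatchMerging Class)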
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)      # 224 -> (224, 224)
        patch_size = to_2tuple(patch_size)  # 4 -> (4, 4)
        # Number of patches along each axis: 224 // 4 = 56
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]  # 56 * 56 = 3136
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # kernel_size == stride == patch_size: one embedding per non-overlapping patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
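img_size: size of the input image, 224x224 pixels by default.
patch_size: size of each patch, 4x4 pixels by default, so each patch covers 4x4=16 pixel positions (16x3=48 values for an RGB image).
in_chans: number of channels of the input image, 3 by default, matching an ordinary RGB image.
embed_dim: number of output channels of the linear projection, i.e. the embedding dimension of each patch, 96 by default.
norm_layer: an optional normalization layer applied to the embeddings.
self.proj: a convolution that turns each patch of the input image into an embedding vector. Its kernel size and stride both equal the patch size, so the image is split into non-overlapping patches.
    def forward(self, x):
        B, C, H, W = x.shape
        # The input must match the configured image size exactly
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> proj -> (B, embed_dim, Ph, Pw) -> flatten(2) -> (B, embed_dim, Ph*Pw)
        # -> transpose(1, 2) -> (B, Ph*Pw, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x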
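To see why a Conv2d whose kernel_size and stride both equal patch_size amounts to "cut the image into non-overlapping patches, then apply one shared linear projection to each patch", here is a minimal NumPy sketch. `patch_embed_numpy` is a hypothetical helper, not part of the Swin source; it assumes H and W are divisible by patch_size and a weight laid out like nn.Conv2d.weight, i.e. (embed_dim, in_chans, p, p). Its output uses the same (B, Ph*Pw, C) layout as `self.proj(x).flatten(2).transpose(1, 2)`.

```python
import numpy as np


def patch_embed_numpy(x, weight, bias, patch_size=4):
    """Hypothetical helper: patch embedding written as reshape + matrix multiply.

    x:      (B, C, H, W) image batch
    weight: (embed_dim, C, p, p), same layout as nn.Conv2d.weight
    bias:   (embed_dim,)
    returns (B, N, embed_dim) with N = (H // p) * (W // p)
    """
    B, C, H, W = x.shape
    p = patch_size
    Ph, Pw = H // p, W // p
    # Split both spatial axes into (patch index, offset inside patch): (B, C, Ph, p, Pw, p)
    patches = x.reshape(B, C, Ph, p, Pw, p)
    # Bring the patch grid forward and flatten each patch to C*p*p values: (B, Ph*Pw, C*p*p)
    patches = patches.transpose(0, 2, 4, 1, 3, 5).reshape(B, Ph * Pw, C * p * p)
    # The conv applies the same linear map to every patch
    w = weight.reshape(weight.shape[0], -1)  # (embed_dim, C*p*p)
    return patches @ w.T + bias


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 224, 224))
weight = rng.standard_normal((96, 3, 4, 4))
bias = rng.standard_normal(96)
out = patch_embed_numpy(x, weight, bias)
print(out.shape)  # (4, 3136, 96)
```

Token n corresponds to patch row n // 56 and column n % 56, which is the same row-major order that flatten(2) produces from the conv output.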
Patch embedding is implemented mainly by the PatchEmbed class in Swin_Transformer.py.
Debugging the forward pass of the PatchEmbed class:
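self.proj: the convolution splits the image into patches and projects each one linearly. The result is then flattened and transposed into shape (B, N, C), where N is the number of patches and C is the embedding dimension. For a 224x224 input the convolution produces a 56x56 grid (56 = 224/4), so N = 56*56 = 3136.
norm_layer: if one is provided, normalization is applied to the embedding vectors.
class BasicLayer(nn.Module):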
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        # Stack `depth` blocks: even-indexed blocks use regular windows (shift_size=0),
        # odd-indexed blocks use shifted windows (shift_size=window_size // 2)
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        # Optional downsampling at the end of the stage (e.g. PatchMerging)
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
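Main parameters:
dim: number of input channels, i.e. the token embedding dimension of this stage.
input_resolution: the (H, W) token grid of this stage, e.g. (56, 56) for the first stage.
depth: number of SwinTransformerBlock modules stacked in this stage.
num_heads: number of attention heads.
window_size: side length of the local attention window.
shift_size: not an argument of BasicLayer itself; blocks with an even index i get shift_size=0 (regular windows) and blocks with an odd index get window_size // 2 (shifted windows), so W-MSA and SW-MSA alternate.
mlp_ratio: ratio of the MLP hidden dimension to dim.
qkv_bias / qk_scale: whether the query, key, and value projections have a bias, and an optional override of the default attention scale head_dim ** -0.5.
drop / attn_drop / drop_path: dropout rate, attention dropout rate, and stochastic-depth rate; drop_path may be a list that gives each block its own rate.
norm_layer: normalization layer, nn.LayerNorm by default.
downsample: optional downsampling module applied at the end of the stage (PatchMerging in Swin); None means no downsampling.
use_checkpoint: if True, each block is run through gradient checkpointing, trading extra computation for lower memory use.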
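The per-block arguments that the list comprehension in `__init__` passes to SwinTransformerBlock can be reproduced in plain Python. The helpers `block_configs` and `linear_drop_path` below are hypothetical and not part of the Swin source. The linearly increasing drop-path schedule split by stage is an assumption about how the SwinTransformer class produces the `drop_path` list, and `drop_path_rate=0.2` is only an example value.

```python
def linear_drop_path(depths, drop_path_rate):
    """Hypothetical helper: a linearly increasing stochastic-depth schedule,
    split into one list per stage (assumed to mirror how drop_path is built)."""
    total = sum(depths)
    step = drop_path_rate / (total - 1) if total > 1 else 0.0
    rates = [k * step for k in range(total)]
    per_stage, start = [], 0
    for d in depths:
        per_stage.append(rates[start:start + d])
        start += d
    return per_stage


def block_configs(depth, window_size, drop_path=0.0):
    """Hypothetical helper: the shift_size / drop_path each block of one BasicLayer receives."""
    configs = []
    for i in range(depth):
        # Same rule as BasicLayer: even index -> regular windows, odd index -> shifted windows
        shift_size = 0 if i % 2 == 0 else window_size // 2
        # A list gives every block its own rate; a scalar is shared by all blocks
        dp = drop_path[i] if isinstance(drop_path, list) else drop_path
        configs.append({"block": i, "shift_size": shift_size, "drop_path": dp})
    return configs


# Example: the third stage (depth 6) with window_size=7
stage_dpr = linear_drop_path([2, 2, 6, 2], drop_path_rate=0.2)
for cfg in block_configs(6, window_size=7, drop_path=stage_dpr[2]):
    print(cfg)
```

Shifting by window_size // 2 on every other block is what lets information cross window boundaries between consecutive blocks, while depth only has to be even for the pattern to end on a shifted block.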
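    def forward(self, x):
        # Pass x through every block of this stage; each block keeps the shape (B, L, C)
        for blk in self.blocks:
            if self.use_checkpoint:
                # Gradient checkpointing (torch.utils.checkpoint): recompute activations
                # during backward instead of storing them, to save memory
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # Downsample (if any) only after all blocks of the stage
        if self.downsample is not None:
            x = self.downsample(x)
        return x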
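The SwinTransformer class holds one BasicLayer instance per stage and calls them one after another in a for loop, so the input shape differs from call to call. Printing the shapes during one forward pass with a batch of 4 images gives: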
input x: torch.Size([4, 3136, 96])
blk(x): torch.Size([4, 3136, 96])
blk(x): torch.Size([4, 3136, 96])
torch.Size([4, 784, 192])
input x: torch.Size([4, 784, 192])
blk(x): torch.Size([4, 784, 192])
blk(x): torch.Size([4, 784, 192])
torch.Size([4, 196, 384])
input x: torch.Size([4, 196, 384])
blk(x): torch.Size([4, 196, 384])
blk(x): torch.Size([4, 196, 384])
blk(x): torch.Size([4, 196, 384])
blk(x): torch.Size([4, 196, 384])
blk(x): torch.Size([4, 196, 384])
blk(x): torch.Size([4, 196, 384])
torch.Size([4, 49, 768])
input x: torch.Size([4, 49, 768])
blk(x): torch.Size([4, 49, 768])
blk(x): torch.Size([4, 49, 768])
torch.Size([4, 49, 768])
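One forward pass runs through the four stages in turn. At the start, 3136 is the sequence length (the number of tokens) and 96 is the dimension of each token vector. From stage to stage the sequence gets shorter while the vector dimension grows. This change comes entirely from the downsampling at the end of each stage, which cuts the token count by a factor of 4 and doubles the dimension; the repeated SwinTransformerBlock calls inside a stage leave the shape unchanged. The last stage has no downsampling, so its output stays at (4, 49, 768).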
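The shape progression in the log can be reproduced with simple arithmetic. This sketch assumes the configuration implied by the log (224x224 input, patch_size=4, embed_dim=96, four stages) and that every stage except the last ends with a 2x2 merge that quarters the token count and doubles the channel dimension; `trace_stage_shapes` is a hypothetical helper, not part of the Swin source.

```python
def trace_stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Hypothetical helper: (seq_len, dim) entering each stage, plus the final output."""
    res = img_size // patch_size  # token grid side after patch embedding (56 here)
    dim = embed_dim
    shapes = []
    for i in range(num_stages):
        # Shape entering stage i; the blocks inside the stage keep it unchanged
        shapes.append((res * res, dim))
        # Assumption: every stage except the last ends with a 2x2 downsampling
        if i < num_stages - 1:
            res //= 2  # 2x2 neighbouring tokens merged -> 4x fewer tokens
            dim *= 2   # channel dimension doubles
    # Output of the last stage (no downsampling)
    shapes.append((res * res, dim))
    return shapes


for seq_len, dim in trace_stage_shapes():
    print((4, seq_len, dim))  # batch size 4, as in the log
```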
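BasicLayer implements one complete stage of Swin Transformer: a stack of self-attention blocks, an optional downsampling step, and stage-level settings such as the number of attention heads and the window size. Because every stage is configured through these same arguments, one class is reused for all four stages, and different model variants can be built by changing only the depths, head counts, and dimensions. This modular design lets Swin Transformer adapt flexibly to different tasks and datasets, while window-based attention keeps the computation efficient.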