swin transformer 模块理解


前言

【个人学习笔记记录,如有错误,请指正】


配置文件使用 swin_small_patch4_windows7_224.yaml 文件,batch_size = 4

一、Patch Embedding

【Patch embedding】其实就是将输入的 224 * 224 大小的图像,经过【卷积】和【LayerNorm】操作,将图像缩放为 56 56 大小的特征图。然后将特征图reshape 为 (4, 3136, 96)形状,这里的 4 为【batch_size】,3136 = 5656,96 为特征图的通道数。

在这里插入图片描述

二、swin transformer block

【swin transformer block】其实就是下面的流程图。
swin transformer 模块理解_第1张图片
这里主要对 【W-MSA】和 【SW-MSA】进行理解。

1.torch.roll 操作

在进行【roll】操作之前,需要将特征图的形状变为(B, H, W, C),即【4, 56, 56, 96】。

【注意:】,这里的【roll】操作是针对【SW-MSA】才有的。
代码如下:

if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
    shifted_x = x

示意图如下:
swin transformer 模块理解_第2张图片
源码中,将特征图移动(-3, -3),如上图就是最后的特征图最后的形状。然后将这个新的特征图,进行窗口的划分,然后进行注意力操作,

2. window_partition 操作

代码如下:

x_windows = window_partition(shifted_x, self.window_size)

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

示意图如下:
swin transformer 模块理解_第3张图片
左边是将特征图划分为7*7 大小的窗口,右边是整个【window_partition】操作的x 的形状变化。

3. W-MSA

将上面的 7 * 7 的特征图,做注意力操作,其中输入的特征图形状为 7 * 7 = 49 和通道数 96,32 是因为多头注意力机制,这里是 3 头注意力机制。

注意力流程示意图如下:
swin transformer 模块理解_第4张图片

代码里的 WindowAttention 的流程大致就是这样子,(这里的位置编码,代码不是很懂,有明白的可以解释解释),其中 mask 机制是在 SW-MSA 中使用到的。

4. SW-MSA

swin transformer 中 为了解决每个窗口之间的交互,引入了对特征图的偏移(偏移量为 3),但是引入偏移之后,源特征图中的窗口数量就变多了,这样就使得计算量变大。这样就引入了 mask 方法。
swin transformer 模块理解_第5张图片

下面是对 mask 掩码生成的过程,这里以特征图大小为 6 * 6 ,窗口大小为 3 * 3,shift 值为 2为例。

原代码做如下修改:

import torch
import matplotlib.pyplot as plt


def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

window_size = 3
shift_size = 2
H, W = 6, 6
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

plt.matshow(img_mask[0, :, :, 0].numpy())
plt.text(0, 0, '0')
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())

plt.show()

对于任意一张大小的特征图,都需要做相同大小的掩码模板。
(这里以 6 * 6 大小的特征图,窗口大小为 3 * 3,shift 为 2 为例)
生成 mask 模板示意图如下:
swin transformer 模块理解_第6张图片
左右两个图示对应的。颜色不同表示不同的掩码区域。

将上述的掩码区域进行如下代码操作:

# img_mask 的形状为(56,56)
# maks_windows 的形状为(64,7,7)
mask_windows = window_partition(img_mask, self.window_size)
# maks_windows 的形状为(64,49)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
# 这样相减,就形成了维度为(64, 49, 49)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# 使用 -100.0 填充非零部分,其余部分使用 0.0 填充
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

给出 attn_mask 的示意图(示意图还是以 6*6 大小的特征图画的):
这里至于为什么这样做,大家可以看出,下面 0 的区域就是上面 mask 模板中的所有部分,非零区域就是多出来的部分。

swin transformer 模块理解_第7张图片
这里就做好了位置掩码操作,只需要在正向传播的过程中和特征图进行相加操作即可。这里对特征图的相加之后会经过【softmax】操作,因为特征图的值通常很小,加上 -100 之后,就会成为-100附近的值,然后经过 softmax 就变成了趋近于 0 的数,这样达到了掩码的作用。

5. downsample 操作

在原论文中,给出了这个流程图,可以注意到特征图大小的减少,和通道数的增加。
swin transformer 模块理解_第8张图片
下面给出 downsample 下采样代码

def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"
    assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
    x = x.view(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
    x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
    x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
    x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
    x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
    x = self.norm(x)
    x = self.reduction(x)
    return x

这里的下采样操作和 【YOLOV5】中的 【Focus】的操作是类似的,大家可以看我的另一篇博客: [YOLOV5 模块理解]
就是将特征图的像素横纵每隔一个取出一个,然后得到四个高宽减半的特征图,然后进行拼接。

总结

这里对 swin transformer 模型的主要模块进行了理解。其中最复杂的就是 mask 部分,感觉还是不太清楚。

你可能感兴趣的:(swin,transfomer,transformer,深度学习,pytorch)