Paper link: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Source code: https://github.com/microsoft/Swin-Transformer
Let's take a look at why Swin Transformer has been sweeping the leaderboards!
Transformers face two major challenges in vision:
1. Visual entities vary greatly in scale, unlike word tokens in NLP.
2. Images contain far more pixels than sentences contain words, so the quadratic cost of global self-attention quickly becomes prohibitive at high resolution.
Swin Transformer was proposed to solve exactly these problems. As the name suggests, the Hierarchical design addresses the first problem, and Shifted Windows address the second.
As shown in the figure, Swin Transformer builds hierarchical feature maps by merging image patches in deeper layers. At the same time, it keeps the computational complexity linear in the input image size: each window contains a fixed number of patches and self-attention is computed only within each window, so the total cost grows linearly with the number of windows, i.e., with image size. The name Swin is in fact short for Shifted WINdows, which is the most elegant part of this paper.
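Concretely, the paper compares the complexity of global multi-head self-attention (MSA) with window-based MSA (W-MSA) on an h×w feature map with channel dimension C and window size M:

Ω(MSA)   = 4hwC² + 2(hw)²C
Ω(W-MSA) = 4hwC² + 2M²hwC

The former is quadratic in the number of patches hw, while the latter is linear in hw once M is fixed (M = 7 by default).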
Although computing self-attention only within each window greatly reduces the model's complexity, different windows cannot exchange any information, which limits the model's representational power. To strengthen it, Shifted Window Attention is introduced: the window partition is shifted alternately between consecutive Swin Transformer blocks.
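Paraphrasing the block equations in the paper (with the residual connections written inline), two consecutive blocks look roughly like:

block l:     z = z + W-MSA(LN(z));   z = z + MLP(LN(z))
block l+1:   z = z + SW-MSA(LN(z));  z = z + MLP(LN(z))

so every other block uses the shifted partition.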
A typical shifted window partition operation is illustrated in the figure below:
The window_partition source code:
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
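As a quick sanity check (a minimal usage sketch, not taken from the repo), partitioning a 4×4 feature map with window_size=2 yields 4 windows of size 2×2; note that H and W must be divisible by window_size:

import torch

x = torch.randn(1, 4, 4, 3)       # B, H, W, C
windows = window_partition(x, 2)  # H and W must be divisible by window_size
print(windows.shape)              # torch.Size([4, 2, 2, 3])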
After reading the source code, I found that it never actually implements the operation that turns the 4 windows into 9, and the naive partition will raise an error when window_size is odd. There is no need to dwell on this, because in practice the computation is carried out with the more efficient method below.
This is implemented by adding a mask to the attention scores, which restricts the self-attention computation so that it is effectively carried out within each sub-window.
For a detailed explanation of torch.roll(), see the article 【Pytorch小知识】torch.roll()函数的用法及在Swin Transformer中的应用(详细易懂).
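As a quick illustration of the cyclic shift (a minimal sketch, not from the Swin source), torch.roll shifts a tensor cyclically along the given dimensions, wrapping the rows/columns that fall off one edge around to the other:

import torch

# a 4x4 grid whose entries encode their original position
grid = torch.arange(16).view(1, 4, 4)   # shape (B, H, W)

# cyclic shift by -1 along both spatial dims, as done before SW-MSA
shifted = torch.roll(grid, shifts=(-1, -1), dims=(1, 2))
print(shifted[0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])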
The relevant part of the source code:
# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    attn_mask = mask_matrix
else:
    shifted_x = x
    attn_mask = None
This is probably the most brilliant idea in the paper: by adding a mask, the attention computed on the shifted windows becomes equivalent to window attention computed within the original sub-windows. This neatly solves the window-inconsistency problem above and makes it possible to compute attention over the irregular windows. The paper does not elaborate on this, so we have to look at the code:
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
Let us again use a 4×4 input as an example.
import numpy as np
import torch

window_size = 2
shift_size = 1
H = 4
W = 4
x = torch.randn(1, H, W, 3)  # only used below for its device
The same code with self removed:
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
#print("img_mask:", img_mask)
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
#print("h_slices:", h_slices)
#print("w_slices:", w_slices)
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        print("img_mask", img_mask)
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
#print("mask_windows:", mask_windows)
mask_windows = mask_windows.view(-1, window_size * window_size)
#print("mask_windows:", mask_windows)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
#print("attn_mask:", attn_mask)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print("attn_mask:", attn_mask)
attn_mask: tensor([[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]],
[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]],
[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]],
[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]])
The four masks correspond to the four shifted windows as follows:
This is the feature map after the roll operation: flatten each window and compute QK^T over it, and the mask pattern above falls out. See the Attention Mask section of 图解Swin Transformer for reference; it arrives at exactly the mask results produced by the code above.
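To see why the label-difference trick works, here is a minimal sketch (not from the repo itself) using the region labels [3, 3, 6, 6] of the third window above: positions coming from different sub-regions get a nonzero difference and are masked with -100.

import torch

labels = torch.tensor([3., 3., 6., 6.])            # flattened region labels of one shifted window
diff = labels.unsqueeze(0) - labels.unsqueeze(1)   # pairwise label differences, shape (4, 4)
mask = diff.masked_fill(diff != 0, float(-100.0))  # nonzero difference => different sub-regions
print(mask)
# tensor([[   0.,    0., -100., -100.],
#         [   0.,    0., -100., -100.],
#         [-100., -100.,    0.,    0.],
#         [-100., -100.,    0.,    0.]])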
if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
After softmax, the entries that had -100 added to them receive essentially zero weight and are thus ignored, which achieves the masking effect: only the valid positions within each sub-window contribute to the attention.
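A quick numerical check of this effect (a minimal sketch, not from the Swin source):

import torch

scores = torch.tensor([1.0, 2.0, 3.0])
masked = scores + torch.tensor([0.0, -100.0, 0.0])  # mask out the middle position
print(torch.softmax(masked, dim=0))
# tensor([0.1192, 0.0000, 0.8808])  (the masked position gets numerically zero weight)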
After the attention is computed, the windows are merged back and the feature map is rolled in the opposite direction, so the final result is consistent with the original (unshifted) layout. The whole procedure is somewhat like the local computation a CNN uses to extract features.
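For completeness, the corresponding steps look roughly like the sketch below: window_reverse is the inverse of window_partition shown earlier, and the cyclic shift is undone with a roll in the opposite direction. Treat this as a sketch of the idea rather than an exact copy of the source; variable names such as attn_windows follow the repo's block code.

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: (num_windows*B, window_size, window_size, C) -> (B, H, W, C)."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

# merge windows back and reverse the cyclic shift
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B, Hp, Wp, C
if self.shift_size > 0:
    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
    x = shifted_x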