Link: https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation
Structure overview:
Backbone: Swin Transformer, built as a series of BasicLayer modules (one per Stage)
Decode_Head: UPerHead
Auxiliary_Head: FCNHead
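Roughly how these three pieces are wired together in the repo's mmsegmentation-style config. This is a sketch from memory of the Swin-T / ADE20K setting; field names follow mmsegmentation conventions, and exact values should be checked against the configs in the repo:

model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,                      # Swin-T
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        out_indices=(0, 1, 2, 3)),         # one feature map per Stage
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],   # channels of the four Stage outputs
        in_index=[0, 1, 2, 3],
        channels=512,
        num_classes=150),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,                   # attached to the stride-16 Stage output
        in_index=2,
        channels=256,
        num_classes=150))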
Patch embedding. Intention: split the image into non-overlapping patches and project each patch to embed_dim channels.
It is just a Conv2d: input shape (B, C, H, W), output shape (B, embed_dim, Wh, Ww), where Wh = H // patch_size and Ww = W // patch_size.
nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
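A minimal sketch of the patch embedding built around that conv (assumes H and W are divisible by patch_size; the real PatchEmbed in the repo also pads and can apply a norm layer):

import torch
import torch.nn as nn

class PatchEmbedSketch(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        # one conv both cuts the image into patch_size x patch_size patches
        # and projects each patch to embed_dim channels
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):       # x: (B, C, H, W)
        return self.proj(x)     # (B, embed_dim, H // patch_size, W // patch_size)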
# 1. Window partition
(B, H, W, C)
(B, H // window_size, window_size, W // window_size, window_size, C)
(B, H // window_size, W // window_size, window_size, window_size, C)
(B * H // window_size * W // window_size, window_size, window_size, C)
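The helper that performs exactly this shape trace, plus its inverse which the block code below relies on; this mirrors the functions in the repo:

import torch

def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows*B, window_size, window_size, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """(num_windows*B, window_size, window_size, C) -> (B, H, W, C)"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)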
mask_windows' shape: (nW, window_size * window_size), where nW = Hp // window_size * Wp // window_size (the mask below is built on a dummy map with B = 1 and C = 1)
attn_mask's shape: broadcasting (nW, 1, window_size * window_size) - (nW, window_size * window_size, 1) gives (nW, window_size * window_size, window_size * window_size)
Efficient batch computation for the shifted configuration
Problem: after shifting, the number of windows increases from ceil(h/M) * ceil(w/M) to (ceil(h/M) + 1) * (ceil(w/M) + 1),
and the shifted windows have varying sizes, most of them smaller than the regular (M, M) window.
How can we batch them efficiently? Padding every small window up to (M, M) would add extra computation, so that is rejected.
The answer is the cyclic shift.
After the shift, the feature map is divided into 9 regions (numbered 1-9), and we need self-attention within each of them.
To reuse the regular W-MSA computation, we cyclically roll the whole feature map by M // 2,
i.e. move everything left by M/2 and up by M/2. The 5th (central) window can then be computed with plain W-MSA, but the other windows now contain tokens from different regions: (6, 4) are mixed, (8, 2) are mixed, and (1, 3, 7, 9) are mixed, so naive attention there would give wrong results.
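A tiny, self-contained demo of the cyclic shift on a labeled grid (hypothetical toy sizes, just to visualize which regions end up sharing a window after torch.roll):

import torch

M = 4                                   # window size (toy value)
H = W = 8                               # toy feature-map size
x = torch.zeros(1, H, W, 1)

# label the 3x3 regions 1..9 the same way the notes above number them
h_slices = (slice(0, -M), slice(-M, -M // 2), slice(-M // 2, None))
w_slices = (slice(0, -M), slice(-M, -M // 2), slice(-M // 2, None))
cnt = 1
for h in h_slices:
    for w in w_slices:
        x[:, h, w, :] = cnt
        cnt += 1

# cyclic shift: move everything up and left by M // 2
shifted = torch.roll(x, shifts=(-M // 2, -M // 2), dims=(1, 2))
print(shifted[0, :, :, 0])              # every window except the central one now mixes labels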
To fix this mixing problem, we add a mask to the self-attention so that tokens from different regions cannot attend to each other:
import numpy as np
import torch

def generate_attn_mask(H, W, window_size, shift_size):
    # 1. give each of the 9 effective regions an id from 0-8
    Hp = int(np.ceil(H / window_size)) * window_size
    Wp = int(np.ceil(W / window_size)) * window_size
    img_mask = torch.zeros([1, Hp, Wp, 1])
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # 2. cut the labeled map into real (window_size, window_size) windows
    # ref: https://zhuanlan.zhihu.com/p/370766757
    mask_windows = window_partition(img_mask, window_size)             # (nW, window_size, window_size, 1)
    mask_windows = mask_windows.view(-1, window_size * window_size)    # (nW, window_size*window_size)
    # pairwise label difference: 0 where two tokens come from the same region
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, ws*ws, ws*ws)
    # after softmax, a -100 logit becomes ~0, so masked pairs get (almost) no attention weight
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
Finally we have the mask, and all the shifted windows can be processed in one parallel batch, just like the regular ones.
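For completeness, a simplified sketch of how the mask enters the window attention itself: it is added to the attention logits, grouped by window index nW, before the softmax (relative position bias and the qkv/output projections are omitted here):

import torch

def masked_window_attention(q, k, v, attn_mask=None):
    # q, k, v: (num_windows*B, num_heads, N, head_dim), N = window_size*window_size
    B_, nH, N, D = q.shape
    attn = (q @ k.transpose(-2, -1)) * D ** -0.5                  # (B_, nH, N, N)
    if attn_mask is not None:                                     # attn_mask: (nW, N, N)
        nW = attn_mask.shape[0]
        attn = attn.view(B_ // nW, nW, nH, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, nH, N, N)
    attn = torch.softmax(attn, dim=-1)                            # -100 logits -> ~0 weight
    return attn @ v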
The only difference from the standard Transformer block is that W-MSA / SW-MSA replaces global multi-head self-attention.
LayerNorm -> W-MSA / SW-MSA -> residual add -> LayerNorm -> MLP -> residual add
# x: (B, H, W, C)
shortcut = x
x = self.norm1(x)
x = padding(x)                                    # pad H, W up to multiples of window_size
_, Hp, Wp, C = x.shape
if shift_size > 0:
    # cyclic shift; attn_mask was precomputed for exactly this shifted layout
    x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
# window partition
x = window_partition(x, window_size)              # (nW*B, window_size, window_size, C)
x = x.view(-1, window_size * window_size, C)
# W-MSA / SW-MSA, all windows in parallel
if shift_size > 0:
    x = self.attention(x, mask=attn_mask)
else:
    x = self.attention(x)
# merge the windows back into a feature map
x = x.view(-1, window_size, window_size, C)
x = window_reverse(x, window_size, Hp, Wp)        # (B, Hp, Wp, C)
if shift_size > 0:
    # reverse the cyclic shift
    x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
# (the real code also removes the padding here)
# residual, then LayerNorm + MLP + residual
x = shortcut + self.dropout(x)
x = x + self.dropout(self.mlp(self.norm2(x)))
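One detail the code above hides: within a BasicLayer the blocks alternate between W-MSA and SW-MSA, which comes down to how shift_size is set per block. A tiny sketch with assumed toy values:

window_size, depth = 7, 6          # toy values (Swin-T stage 3 uses depth 6, window 7)
shift_sizes = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
print(shift_sizes)                 # [0, 3, 0, 3, 0, 3] -> W-MSA, SW-MSA, W-MSA, ...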
ref: https://zhuanlan.zhihu.com/p/367111046
Patch Merging (the downsampling between stages): take one pixel from each position of every 2x2 neighborhood, concatenate the four sub-maps along channels, then reduce 4C -> 2C with a linear layer.
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
x = padding(x)                            # (N, H, W, C), pad H and W to even numbers
x0 = x[:, 0::2, 0::2, :]                  # (N, H/2, W/2, C)
x1 = x[:, 0::2, 1::2, :]                  # (N, H/2, W/2, C)
x2 = x[:, 1::2, 0::2, :]                  # (N, H/2, W/2, C)
x3 = x[:, 1::2, 1::2, :]                  # (N, H/2, W/2, C)
x = torch.cat([x0, x1, x2, x3], dim=-1)   # (N, H/2, W/2, 4C)
x = x.view(N, -1, 4 * C)                  # (N, H/2 * W/2, 4C)
x = self.norm(x)
x = self.reduction(x)                     # 4C -> 2C
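A quick shape check with toy sizes (the real module receives flattened tokens (N, H*W, C) plus H and W, but the idea is the same):

import torch
import torch.nn as nn

N, H, W, C = 2, 8, 8, 96
x = torch.randn(N, H, W, C)

x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 0::2, 1::2, :]
x2 = x[:, 1::2, 0::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], dim=-1).view(N, -1, 4 * C)   # (2, 16, 384)

reduction = nn.Linear(4 * C, 2 * C, bias=False)
print(reduction(merged).shape)    # torch.Size([2, 16, 192]) -> (N, H/2 * W/2, 2C)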
The pyramid features fed to the segmentation heads are simply the outputs of each BasicLayer / Stage: four feature maps at strides 4, 8, 16 and 32 with C, 2C, 4C and 8C channels.
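For Swin-T (embed_dim = 96, depths = [2, 2, 6, 2]) this gives the following pyramid, which is what UPerHead consumes as its multi-scale inputs. A small sketch assuming a 512x512 input, not copied from the code:

embed_dim = 96
strides = [4, 8, 16, 32]
channels = [embed_dim * 2 ** i for i in range(4)]     # [96, 192, 384, 768]
for s, c in zip(strides, channels):
    print(f"stride {s:2d}: (B, {c}, {512 // s}, {512 // s})")
# stride  4: (B, 96, 128, 128)
# stride  8: (B, 192, 64, 64)
# stride 16: (B, 384, 32, 32)
# stride 32: (B, 768, 16, 16)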
References:
https://zhuanlan.zhihu.com/p/367111046
https://zhuanlan.zhihu.com/p/370766757
https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation