SwinTransformer-Segmentation Code Walkthrough

  • Backbone: Swin Transformer
    • Patch Embedding
    • A series of BasicLayers (stages), each containing
      • n Swin Transformer Blocks
        • W-MSA / SW-MSA
        • FFN / MLP
      • Patch Merging (downsampling at the end of every stage except the last)
  • Decode_Head: UperHead
  • Auxiliary_Head: FCNHead
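For orientation, a small sketch of the feature shapes each stage produces. It assumes the Swin-T defaults patch_size=4 and embed_dim=96, padding to a multiple of the patch size, and Patch Merging halving the resolution while doubling the channels (as described below). `stage_shapes` is a hypothetical helper, not part of the codebase.

```python
import math

def stage_shapes(H, W, patch_size=4, embed_dim=96, num_stages=4):
    """Return (channels, height, width) of each stage's output feature map.

    Assumes Patch Embedding pads the image up to a multiple of patch_size and
    Patch Merging pads odd sizes to even before halving (hence ceil division).
    """
    h = math.ceil(H / patch_size)
    w = math.ceil(W / patch_size)
    c = embed_dim
    shapes = []
    for _ in range(num_stages):
        shapes.append((c, h, w))
        # Patch Merging between stages: 2x downsampling, channels 4C -> 2C
        h, w, c = math.ceil(h / 2), math.ceil(w / 2), c * 2
    return shapes

print(stage_shapes(512, 512))
# [(96, 128, 128), (192, 64, 64), (384, 32, 32), (768, 16, 16)]
```

These four maps (strides 4, 8, 16, 32) are the multi-scale features the decode head works on.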

Patch Embedding (Patch Partition)

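  1. Purpose: split the image into non-overlapping patches and linearly embed each patch.

  2. It is just a Conv2d with kernel_size = stride = patch_size. Input shape: (B, C, H, W); output shape: (B, embed_dim, Wh, Ww), where Wh = H / patch_size and Ww = W / patch_size (the input is padded first if H or W is not divisible by patch_size).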

nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
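Because kernel_size equals stride, every output position sees exactly one patch, so this convolution is the same as cutting the image into p × p patches and applying one shared linear layer. The sketch below checks that claim in NumPy (a stand-in for PyTorch so it runs anywhere; the function names are illustrative only) by comparing a direct strided-convolution loop with the reshape + matmul form.

```python
import numpy as np

def patch_embed_reshape(x, weight, bias, p):
    """Patch embedding as reshape + linear. x: (B, C, H, W), weight: (E, C, p, p)."""
    B, C, H, W = x.shape
    E = weight.shape[0]
    Wh, Ww = H // p, W // p
    # (B, C, Wh, p, Ww, p) -> (B, Wh, Ww, C, p, p) -> (B, Wh*Ww, C*p*p)
    patches = x.reshape(B, C, Wh, p, Ww, p).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, Wh * Ww, C * p * p)
    out = patches @ weight.reshape(E, C * p * p).T + bias  # (B, Wh*Ww, E)
    return out.transpose(0, 2, 1).reshape(B, E, Wh, Ww)

def patch_embed_conv_loop(x, weight, bias, p):
    """Reference: direct convolution with kernel_size = stride = p (no overlap)."""
    B, C, H, W = x.shape
    E = weight.shape[0]
    out = np.zeros((B, E, H // p, W // p))
    for i in range(H // p):
        for j in range(W // p):
            patch = x[:, :, i * p:(i + 1) * p, j * p:(j + 1) * p]  # (B, C, p, p)
            out[:, :, i, j] = np.einsum('bchw,echw->be', patch, weight) + bias
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))
weight = rng.standard_normal((5, 3, 4, 4))
bias = rng.standard_normal(5)
out = patch_embed_reshape(x, weight, bias, 4)
print(out.shape)                                                   # (2, 5, 2, 2)
print(np.allclose(out, patch_embed_conv_loop(x, weight, bias, 4)))  # True
```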



# 1. Window partition
(B, H, W, C)
-> view:    (B, H // window_size, window_size, W // window_size, window_size, C)
-> permute: (B, H // window_size, W // window_size, window_size, window_size, C)
-> view:    (B * H // window_size * W // window_size, window_size, window_size, C)
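The same view / permute / view sequence, sketched in NumPy (standing in for torch) together with its inverse. Passing H and W explicitly to window_reverse is what makes the inverse well defined; the check confirms that windows come out batch-major, then row-major over the window grid, and that reversing restores the input.

```python
import numpy as np

def window_partition(x, window_size):
    """(B, H, W, C) -> (B * nW, window_size, window_size, C); H, W divisible by window_size."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H//ws, W//ws, ws, ws, C)
    return x.reshape(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: (B * nW, ws, ws, C) -> (B, H, W, C)."""
    C = windows.shape[-1]
    x = windows.reshape(-1, H // window_size, W // window_size, window_size, window_size, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H//ws, ws, W//ws, ws, C)
    return x.reshape(-1, H, W, C)

x = np.arange(2 * 8 * 8 * 3).reshape(2, 8, 8, 3)
windows = window_partition(x, 4)
print(windows.shape)  # (8, 4, 4, 3): 2 images x 4 windows each
```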

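mask_windows shape: (nW, window_size * window_size), where nW = (Hp // window_size) * (Wp // window_size). img_mask has B = 1 and C = 1, so the batch and channel factors drop out.

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2): (nW, 1, window_size * window_size) - (nW, window_size * window_size, 1), which broadcasts to (nW, window_size * window_size, window_size * window_size).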

Efficient batch computation v1

  1. The number of windows increases: ceil(h/M) * ceil(w/M) -> (ceil(h/M) + 1) * (ceil(w/M) + 1)

  2. The windows have different sizes, and many of them are smaller than the original (M, M)
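A quick count for an 8 × 8 feature map with M = 4 and a shift of M // 2 = 2 reproduces both points: 4 regular windows become 9 shifted ones, and 8 of the 9 are smaller than 4 × 4. `shifted_window_sizes` is a hypothetical helper written only for this illustration.

```python
import math

def shifted_window_sizes(length, M, shift):
    """Sizes of the window pieces along one axis when the window grid is
    displaced by `shift` (no cyclic shift, no padding).
    Assumes length is a multiple of M and 0 < shift < M."""
    cuts = [0] + list(range(M - shift, length, M)) + [length]
    return [b - a for a, b in zip(cuts, cuts[1:])]

h = w = 8
M, shift = 4, 2
regular = math.ceil(h / M) * math.ceil(w / M)
sizes_h = shifted_window_sizes(h, M, shift)
sizes_w = shifted_window_sizes(w, M, shift)
print(regular)                      # 4
print(sizes_h)                      # [2, 4, 2]
print(len(sizes_h) * len(sizes_w))  # 9
```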

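How can these windows be batched efficiently? Padding every window to (M, M) would work, but it adds computation (2 × 2 regular windows become 3 × 3, i.e. 2.25× as many), so no!

The answer is a cyclic shift.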

Cyclic Shift 

Now we need to compute self-attention inside each of the windows 1-9.

To do it the same way as in W-MSA, we cyclically roll the whole feature map by M//2.

That is, shift the map M//2 to the left and M//2 up. After the shift the map splits into regular (M, M) windows again: the window holding region 5 can be computed exactly as in W-MSA, but the other windows would give wrong results, because they mix regions that are not adjacent in the original map: (6, 4), (8, 2) and (1, 3, 7, 9).

To fix this mixing problem, we add a mask to the attention logits so that each token only attends to tokens from its own region:
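The mixing can be checked numerically on an 8 × 8 map with M = 4. The nine regions are numbered 1-9 row by row (so region 5 is the centre, matching the numbering used in the text); the labelling scheme itself is an assumption for illustration.

```python
import numpy as np

M, shift = 4, 2   # window size and shift size (M // 2)
H = W = 8

# Number the 9 regions of the shifted-window grid 1..9 row by row.
# Cut points along each axis: [0, 2), [2, 6), [6, 8).
bands = np.array([0] * shift + [1] * M + [2] * shift)  # band index per row/column
regions = 3 * bands[:, None] + bands[None, :] + 1       # (H, W), ids 1..9

# Cyclic shift: move the whole map M//2 up and M//2 left.
rolled = np.roll(regions, shift=(-shift, -shift), axis=(0, 1))

# Cut the rolled map into regular M x M windows and record which regions each holds.
window_regions = {}
for i in range(H // M):
    for j in range(W // M):
        win = rolled[i * M:(i + 1) * M, j * M:(j + 1) * M]
        window_regions[(i, j)] = sorted(set(win.ravel().tolist()))
print(window_regions)
# {(0, 0): [5], (0, 1): [4, 6], (1, 0): [2, 8], (1, 1): [1, 3, 7, 9]}
```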

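import numpy as np
import torch


def window_partition(x, window_size):
    # (B, H, W, C) -> (B * nW, window_size, window_size, C), see "1. Window partition"
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def generate_attn_mask(H, W, window_size, shift_size):

    # 1. give every region of the shifted layout an id from 0 to 8
    Hp = int(np.ceil(H / window_size)) * window_size  # padded height
    Wp = int(np.ceil(W / window_size)) * window_size  # padded width
    img_mask = torch.zeros([1, Hp, Wp, 1])
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    # 2. cut img_mask into (window_size, window_size) windows and compare ids pairwise
    # ref: https://zhuanlan.zhihu.com/p/370766757
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)  # nW, window_size*window_size
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nW, window_size*window_size, window_size*window_size
    # non-zero difference = tokens from different regions; -100 is added to the logits,
    # so softmax (not sigmoid) gives them a weight of ~exp(-100), i.e. effectively 0
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask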

Finally, with this mask, all windows (including the shifted ones) can be processed in parallel in one batched attention call.
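Why -100 is enough: the mask is added to the attention logits before the softmax, and exp(-100) ≈ 3.7e-44, so masked keys receive a weight that is zero for all practical purposes. A tiny NumPy check with made-up logits:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Logits of one query against 4 keys; keys 2 and 3 belong to another region.
logits = np.array([0.5, 1.0, 2.0, 3.0])
mask = np.array([0.0, 0.0, -100.0, -100.0])  # the values attn_mask uses

weights = softmax(logits + mask)
# Masked keys get ~0 weight; the rest match a softmax over the unmasked keys only.
print(weights)
```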



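The only difference from a standard Transformer block is that multi-head self-attention is replaced by W-MSA (or SW-MSA; consecutive blocks alternate between the two).

LayerNorm -> W-MSA / SW-MSA -> residual add -> LayerNorm -> MLP -> residual add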

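import torch
import torch.nn.functional as F

# inside SwinTransformerBlock.forward(x, H, W, attn_mask); x: (B, H*W, C)
B, L, C = x.shape
shortcut = x
x = self.norm1(x).view(B, H, W, C)

# pad H and W up to multiples of window_size
pad_r = (window_size - W % window_size) % window_size
pad_b = (window_size - H % window_size) % window_size
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
_, Hp, Wp, _ = x.shape

if shift_size > 0:
    # cyclic shift; attn_mask was generated for this shifted layout
    x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# window partition
x = window_partition(x, window_size)
x = x.view(-1, window_size * window_size, C)

# W-MSA or SW-MSA, all windows in parallel
if shift_size > 0:
    x = self.attention(x, mask=attn_mask)
else:
    x = self.attention(x)

# merge windows back into a feature map
x = x.view(-1, window_size, window_size, C)
x = window_reverse(x, window_size, Hp, Wp)  # (B, Hp, Wp, C)

if shift_size > 0:
    # reverse cyclic shift
    x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))

# drop the padding
x = x[:, :H, :W, :].contiguous().view(B, H * W, C)

# residual connections: attention branch, then LayerNorm + MLP branch
# (drop_path is stochastic depth in the official code)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))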
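To see the whole SW-MSA data path working end to end, here is a NumPy sketch (roll, partition, masked attention, reverse, roll back). It is deliberately simplified and hypothetical: single head, q = k = v = x with no projections or relative position bias, no LayerNorm/MLP, and the input is assumed already padded. The check perturbs one pixel and confirms that only tokens from the same original region react, even though its shifted window also holds tokens from the opposite corner of the map.

```python
import numpy as np

def window_partition(x, ws):
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, ws, ws, C)

def window_reverse(windows, ws, H, W):
    C = windows.shape[-1]
    x = windows.reshape(-1, H // ws, W // ws, ws, ws, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, H, W, C)

def make_attn_mask(H, W, ws, shift):
    # Region ids on the shifted layout, same slices as generate_attn_mask.
    img_mask = np.zeros((1, H, W, 1))
    slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mw = window_partition(img_mask, ws).reshape(-1, ws * ws)  # (nW, N)
    diff = mw[:, None, :] - mw[:, :, None]                    # (nW, N, N)
    return np.where(diff != 0, -100.0, 0.0)

def window_attention(x, mask=None):
    # Toy single-head attention, q = k = v = x. x: (B*nW, N, C); mask: (nW, N, N).
    logits = x @ x.transpose(0, 2, 1) / np.sqrt(x.shape[-1])
    if mask is not None:
        # windows are batch-major, so the same nW masks repeat for every image
        logits = logits + np.tile(mask, (x.shape[0] // mask.shape[0], 1, 1))
    logits -= logits.max(-1, keepdims=True)
    w = np.exp(logits)
    return (w / w.sum(-1, keepdims=True)) @ x

def sw_msa(x, ws, shift):
    # x: (B, H, W, C) with H and W already multiples of ws.
    B, H, W, C = x.shape
    mask = make_attn_mask(H, W, ws, shift) if shift > 0 else None
    if shift > 0:
        x = np.roll(x, (-shift, -shift), axis=(1, 2))
    win = window_partition(x, ws).reshape(-1, ws * ws, C)
    win = window_attention(win, mask)
    x = window_reverse(win.reshape(-1, ws, ws, C), ws, H, W)
    if shift > 0:
        x = np.roll(x, (shift, shift), axis=(1, 2))
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 8, 8, 4))
y = sw_msa(x, 4, 2)
x2 = x.copy()
x2[0, 0, 0, :] += 5.0          # perturb the top-left pixel
y2 = sw_msa(x2, 4, 2)
# (0, 1) shares its region with (0, 0); (6, 6) shares only the shifted window.
print(np.allclose(y[0, 0, 1], y2[0, 0, 1]), np.allclose(y[0, 6, 6], y2[0, 6, 6]))
```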

Patch Merging

ref: https://zhuanlan.zhihu.com/p/367111046

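import torch
import torch.nn as nn
import torch.nn.functional as F

self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = nn.LayerNorm(4 * dim)

# forward: x is (N, H, W, C); pad odd H / W to even sizes first
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))  # (N, H, W, C), H and W now even
# slice order as in the official code (matters when loading pretrained reduction weights)
x0 = x[:, 0::2, 0::2, :]  # (N, H/2, W/2, C) even rows, even cols
x1 = x[:, 1::2, 0::2, :]  # (N, H/2, W/2, C) odd rows, even cols
x2 = x[:, 0::2, 1::2, :]  # (N, H/2, W/2, C) even rows, odd cols
x3 = x[:, 1::2, 1::2, :]  # (N, H/2, W/2, C) odd rows, odd cols

x = torch.cat([x0, x1, x2, x3], dim=-1)  # (N, H/2, W/2, 4C)
x = x.view(N, -1, 4 * C)  # (N, H/2 * W/2, 4C)
x = self.norm(x)
x = self.reduction(x)  # from 4C to 2C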
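The gather step of Patch Merging in NumPy (standing in for torch), leaving out the LayerNorm and the 4C -> 2C Linear. The slice order follows the official Swin code (x1 takes odd rows, even columns); any fixed order works as long as it matches the one the reduction weights were trained with. `patch_merge_gather` is a hypothetical name.

```python
import numpy as np

def patch_merge_gather(x):
    """(N, H, W, C) -> (N, ceil(H/2) * ceil(W/2), 4C); odd H/W are zero-padded first."""
    N, H, W, C = x.shape
    x = np.pad(x, ((0, 0), (0, H % 2), (0, W % 2), (0, 0)))
    x0 = x[:, 0::2, 0::2, :]  # even rows, even cols
    x1 = x[:, 1::2, 0::2, :]  # odd rows, even cols
    x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols
    x3 = x[:, 1::2, 1::2, :]  # odd rows, odd cols
    x = np.concatenate([x0, x1, x2, x3], axis=-1)  # (N, H/2, W/2, 4C)
    return x.reshape(N, -1, 4 * C)

x = np.arange(16).reshape(1, 4, 4, 1)
out = patch_merge_gather(x)
print(out.shape)  # (1, 4, 4)
print(out[0, 0])  # [0 4 1 5]: the 2x2 block x[0:2, 0:2] in x0, x1, x2, x3 order
```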


The multi-scale (pyramid) features Swin provides for segmentation are the outputs of every BasicLayer (stage), taken before Patch Merging, at strides 4, 8, 16 and 32.

