SwinTransformer-Segmentation Code Walkthrough

Link: https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation

Structure overview:


  • Backbone: Swin Transformer

    • Patch Embedding

    • A stack of BasicLayers (stages), each containing:

      • n Swin Transformer Blocks
        • W-MSA / SW-MSA
        • FFN / MLP

      • Patch Merging

  • Decode_Head: UPerHead

  • Auxiliary_Head: FCNHead

Patch Embedding (Patch Partition)

  1. Intention: split the image into non-overlapping patches.

  2. Implemented as a single conv2d: input shape (B, C, H, W), output shape (B, embed_dim, Wh, Ww), where Wh = H // patch_size and Ww = W // patch_size (a minimal sketch follows the conv line below).

nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
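
A minimal sketch of this step, assuming the repo's default patch_size=4 and embed_dim=96 (the real PatchEmbed in the repo also pads H/W and applies an optional LayerNorm, which are omitted here):

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split the image into patch_size x patch_size patches with a strided conv."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):     # x: (B, C, H, W), H and W divisible by patch_size
        return self.proj(x)   # (B, embed_dim, H // patch_size, W // patch_size)

x = torch.randn(2, 3, 224, 224)
print(PatchEmbed()(x).shape)  # torch.Size([2, 96, 56, 56])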

BasicLayer

W-MSA / SW-MSA

# 1. Window partition: shape trace inside window_partition (view -> permute -> view)
(B, H, W, C)
-> (B, H // window_size, window_size, W // window_size, window_size, C)      # view
-> (B, H // window_size, W // window_size, window_size, window_size, C)      # permute
-> (B * H // window_size * W // window_size, window_size, window_size, C)    # view
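
The trace above corresponds to the window_partition helper; reconstructed, it looks roughly like this:

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size, window_size, C)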

mask_windows shape: (nW, window_size * window_size), where nW = (Hp // window_size) * (Wp // window_size)

attn_mask shape: (nW, window_size * window_size, window_size * window_size), obtained by broadcasting mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)


Efficient batch computation v1


  1. The number of windows increases: ceil(h/M) * ceil(w/M) -> (ceil(h/M) + 1) * (ceil(w/M) + 1)

  2. The window sizes become irregular, and some windows are smaller than the original (M, M)

How can we still batch the computation efficiently? Padding the small windows back to (M, M) would add extra computation, so that is not the answer.

The answer is the cyclic shift.

Cyclic Shift 


Now we need to compute self-attention within windows 1-9.

To do this exactly as in W-MSA, we roll the whole feature map by M // 2:


After every window is shifted left by M/2 and up by M/2, window 5 can be computed with the plain W-MSA method, but the other windows would give wrong results because non-adjacent regions end up in the same window: (6, 4) are mixed, (8, 2) are mixed, and (1, 3, 7, 9) are mixed.
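
A toy example of the roll on a 4x4 map with window size M = 2 (so the shift is M // 2 = 1), just to make the cyclic wrap-around concrete:

import torch

x = torch.arange(16).view(1, 4, 4, 1)                  # toy (B, H, W, C) feature map
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))  # shift up and left by M // 2 = 1
print(shifted[0, :, :, 0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])   <- the first row/column wrapped around to the end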

To fix the mixing problem, we add a mask to the self-attention:

import numpy as np
import torch

def generate_attn_mask(H, W, window_size, shift_size):
    # 1. assign an id from 0-8 to each of the 9 regions of the (padded) feature map
    Hp = int(np.ceil(H / window_size)) * window_size
    Wp = int(np.ceil(W / window_size)) * window_size
    img_mask = torch.zeros((1, Hp, Wp, 1))
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))

    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    # 2. cut the id map into real (window_size, window_size) windows
    # ref: https://zhuanlan.zhihu.com/p/370766757
    mask_windows = window_partition(img_mask, window_size)             # (nW, window_size, window_size, 1)
    mask_windows = mask_windows.view(-1, window_size * window_size)    # (nW, window_size*window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, ws*ws, ws*ws)
    # pairs from different regions get -100; after softmax their attention weight is ~0
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
    return attn_mask


Finally we have the mask, and SW-MSA can be computed for all windows in parallel.
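
Inside the window attention the mask is simply added to the attention logits before the softmax. A self-contained sketch of that step (the function name and arguments below are illustrative; the repo does this inside WindowAttention.forward):

import torch
import torch.nn.functional as F

def apply_window_mask(attn, mask, num_heads):
    # attn: (B * nW, num_heads, N, N) logits with N = window_size * window_size
    # mask: (nW, N, N) from generate_attn_mask, or None for plain W-MSA
    if mask is not None:
        nW, N = mask.shape[0], mask.shape[-1]
        attn = attn.view(-1, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, N, N)
    return F.softmax(attn, dim=-1)  # pairs masked with -100 get ~0 attention weight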

SwinTransformerBlock

structure


The only difference from a standard transformer block is that W-MSA / SW-MSA replaces the full self-attention.

LayerNorm -> W-MSA / SW-MSA -> residual -> LayerNorm -> MLP -> residual

shortcut = x
x = self.norm1(x)
x = padding(x)  # pad H, W to multiples of window_size -> (B, Hp, Wp, C)

if shift_size > 0:
    # cyclic shift; attn_mask was built for exactly this shift
    x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# window partition
x = window_partition(x, window_size)          # (nW*B, window_size, window_size, C)
x = x.view(-1, window_size * window_size, C)

# W-MSA / SW-MSA (the mask makes the shifted case run in parallel)
if shift_size > 0:
    x = self.attention(x, mask=attn_mask)
else:
    x = self.attention(x)

# merge windows back into a feature map
x = x.view(-1, window_size, window_size, C)
x = window_reverse(x, window_size, Hp, Wp)    # (B, Hp, Wp, C)

# reverse the cyclic shift
if shift_size > 0:
    x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))

# remove the padding added at the top, then residual + FFN
x = x[:, :H, :W, :].contiguous()
x = shortcut + self.dropout(x)
x = x + self.dropout(self.mlp(self.norm2(x)))
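
The block calls window_reverse, the inverse of window_partition, which is not shown elsewhere in these notes; reconstructed, it looks roughly like this:

def window_reverse(windows, window_size, H, W):
    # (num_windows * B, window_size, window_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)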

Patch Merging

ref: https://zhuanlan.zhihu.com/p/367111046

self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

x = padding(x)            # pad H and W to even numbers -> (B, H, W, C)
x0 = x[:, 0::2, 0::2, :]  # (B, H/2, W/2, C)
x1 = x[:, 0::2, 1::2, :]  # (B, H/2, W/2, C)
x2 = x[:, 1::2, 0::2, :]  # (B, H/2, W/2, C)
x3 = x[:, 1::2, 1::2, :]  # (B, H/2, W/2, C)

x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)
x = x.view(B, -1, 4 * C)                 # (B, H/2 * W/2, 4C)
x = self.norm(x)
x = self.reduction(x)                    # 4C -> 2C
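
Putting the snippet together as a minimal, runnable module (a sketch that assumes H and W are already even, so the padding step is omitted):

import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    # downsample 2x in space, C -> 2C in channels
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):                        # x: (B, H, W, C) with even H, W
        B, H, W, C = x.shape
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 0::2, 1::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)
        x = x.view(B, -1, 4 * C)                 # (B, H/2 * W/2, 4C)
        return self.reduction(self.norm(x))      # (B, H/2 * W/2, 2C)

# e.g. a Swin-T stage-1 feature map: (1, 56, 56, 96) -> (1, 784, 192)
print(PatchMerging(96)(torch.randn(1, 56, 56, 96)).shape)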

FCNHead


In this repo the multi-scale features fed to the heads are the outputs of every BasicLayer / stage of the backbone: UPerHead consumes all four stage outputs (C, 2C, 4C, 8C channels), while the auxiliary FCNHead takes one intermediate stage output.
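
For reference, the mmsegmentation-style model config in the repo wires these pieces together roughly like this (the numbers are the Swin-T / ADE20K values as I recall them and are only illustrative; check the repo's configs/ directory for the exact settings):

model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        out_indices=(0, 1, 2, 3)),          # one feature map per BasicLayer / stage
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],    # C, 2C, 4C, 8C
        in_index=[0, 1, 2, 3],
        channels=512,
        num_classes=150),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,                    # an intermediate stage output
        in_index=2,
        channels=256,
        num_classes=150))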

References

  • https://zhuanlan.zhihu.com/p/367111046

  • https://zhuanlan.zhihu.com/p/370766757

  • https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation
