Transformer在图像中的运用(三)SwinTransformer原理及代码解读

说之前先提一个视频这个视频还是很好的将transformer机制的变迁及未来的趋势很详细的说明了一下我觉得蛮有感触的,建议可以看看这里首先提一下代码及其对应的论文视频地址。
paper:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
code: microsoft/Swin-Transformer
可以理解SwinTransformer是新一代的特征提取神器,很多榜单都有它的影子,这里我们可以理解为是一种新的`backbone,如下所示支持多种下游任务。相对比之前说的Transformer 在图像中的运用(一)VIT(Transformers for Image Recognition at Scale)论文及代码解读 之前需要每个像素

一、 原理

在Transformer种,如果图像像素太多则我们需要构建出更多的特征序列,这样就会导致我们的效率降低,所以我们采用了窗口以及分层的形式来替代长序列。

1.1 整体网络架构

  • 得到各Patch特征构建的序列(注意这里先卷积得到特征图,再对特征图进行切分成Patch
  • 分成计算attention(逐步下采样过程)
  • 其中Block是最核心的, 对attention的计算方法进行了改进

由下面的图我们可以看出特征图大小不断减小, 但是特征图的通道数不断增加。


Swin整体网络结构
1.1.1 Patch Embedding

下面举一个例子比如输入的图像数据为(224, 224, 3), 输出(3136, 96)相当于序列长度为3136, 每个向量是96维特征。这里的卷积核我们使用Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))。所以3136就是卷积(224 / 4) * (224 / 4)得到的。

这时候我们得到的输入特征图为(56, 56, 96), 如果默认窗口大小为7,所以总共可以分为8 * 8个窗口。则输出的特征图为(64(8*8), 7, 7, 96) 之前单位是序列, 现在单位是窗口(工64个窗口)

1.1.2 Swin Transformer Block

下面我们来看下上面图中对应的Transformer Blocks是什么样子, 如下图所示。

Swin Transformer Block

上图的两个组合是串联而成的Block,对于左边为基于窗口的注意力计算W-MSA(multi-head self attention modules with regular),对于右边为窗口滑动后重新计算注意力SW-MSA(multi-head self attention modules with shifted windowing)

1. W-MSA(计算每个不同窗口自身的注意力机制(下面不同颜色的矩形代表不同的窗口))


对得到的窗口,计算各个窗口自己的自注意力得分,qkv三个矩阵放在一起得到(3, 64, 3, 49, 32)

  • 3个矩阵
  • 64个窗口
  • 3个heads
  • 7*7的窗口大小(每个窗口有49个token即49个像素)
  • 96/3=32个单head特征

所以attention结果为(64, 3, 49, 49) 每个头都会得出每个窗口内的自注意力(3为头,这里可以理解为不同窗口不同头对应窗口的不同token之间的注意力)。
通过上面的计算我们可以得到新的特征(64, 49, 96), 之后再进行reshape操作将其还原到(56, 56, 96)大小特征图目的就是为了还原输入特征图大小(但是其已经计算过了attentation), 因为再transformer要经过多层输入大小与输出大小一般都是相同的。
\color{red}{这里顺便提下主要就是有这篇论文windows机制相对于VIT来说需要对特征图上的每个像素相互进行QKV运算,进行信息沟通。 而这里采用了windows机制,} \color{red}{将特征图分成一个个windows,我们只在每个windows内部进行MSA,可以大大减少计算量, 但是有一个缺点就是窗口之间是无法进行信息交互的, 从而导致我们的感受野变小}下面给出了省出来的计算量。

h,w,c分别代表特征图高度宽度和深度, M代表窗口大小


矩阵计算评估计算量

这里计算量公式可以参考这篇文章Swin-Transformer网络结构详解。

2. SW-MSA(计算不同窗口之间的注意力机制)
上面W-MSA是只是知道窗口内部的特征,但是我们不知道窗口之间的特征我们可以用SW-MSA机制来弥补。这里的主要区别就是S(shift滑动),我们如何去做滑动呢?

transformer偏移

上图中我们可以看出网格由红色网格(b)移动到了蓝色网格(c),我们需要通过将上方蓝色区域移动到下方,左边红色区域移动到右边。这么做的目的如下:
https://www.zhihu.com/question/492057377/answer/2213112296

记住这里是半个窗口, 还有一点记住是向下取整(如窗口大小3, 则移动为1)
说白了就是换一换所有不同窗口的匹配对,使得模型更加健壮,这就是滑动操作。

由于不同Windows之间互不重叠,每次进行自注意力计算时很显然就丢失了Windows之间的信息,那么如何在降低计算量的同时保留全局信息呢?Shifted Window应运而生。


上面这张图可以用如下的示意图理解:




但是还有一个问题原来是4个windows,但是移动之后变成了9个windows,为了能够做到并行计算应该如何解决呢?我们可以做如下偏移方法。



则得到如下效果:

Attention Mask 机制
因为我们区域(5,3) (7,1) (8,6,2,0)本来是之间不想连接的,所以我们要单独计算各自的区域的MSA。我们借用区域(5,3)举例,这篇博客对于这个解释非常棒Swin-Transformer网络结构详解, 如下所示:


这里我们仅仅计算区域5的信息而不想引入区域3的信息,我们通过掩码mask的方式即可计算。因为本来公式中 是一个很小的数字如果我们减去100, 再经过softmax可以理解为就是为0了。
示例1

示例2

注意,全部计算完后需要将数据挪回到原来的位置上。下面演示一下整体流程
流程1

流程2

流程3

因为要经过多层transformer通过W-MSA以及SW-MSA输出的大小保持不变(56*56*96)

1.1.2 Relative Position Bias


下面我们看下加相对偏置与不加相对偏置的效果
Table 4. Ablation study on the shifted windows approach and different position embedding methods on three benchmarks, using the Swin-T architecture. w/o shifting: all self-attention modules adopt regular window partitioning, without shifting; abs. pos.: absolute position embedding term of ViT; rel. pos.: the default settings with an additional relative position bias term (see Eq. (4)); app.: the first scaled dot-product term in Eq. (4).

发现使用rel.pos相对位置偏置更加合理。
上述相对位置小矩阵摊平再拼接就得到下面的大矩阵

如何将一元坐标转成二元坐标呢?我们看作者如何去做的。
偏移从0开始行、列标加上M(窗口大小->2*2)-1

行标乘上2M-1

行列相加

上述就可以得出我们下面的公式B
1.1.3 PatchMerging

network structure

这里我们就要说到这里Patch Merging操作。它的作用可以缩小特征图大小,提升特征图的通道数(这里也可以理解为就是下采样操作)。

二、 代码逻辑解读

# file: models/swin_transformer.py
# class: SwinTransformer
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        # 这里的drop rate是会随着模型不同stage不断提升到我们设定的rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), # 我们的深度不断乘上2
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, # 这里transoformer和patchMerge是连在一起的最后一个没有transformer只有patchMerge
                               use_checkpoint=use_checkpoint) 
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

2.1 input embedding

# file: models/swin_transformer.py
# class: SwinTransformer
    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

这里我们的输入大小为4(batch), 3(channel), 224(width), 224(height), 接着进入到self.patch_embed操作。

# file: swin_transformer.py
# class: PatchEmbed
    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

和以往vit一样,这里做self.proj就是进行卷积操作

# 卷积核大小为4, stride也是为4, 这样会导致输出特征图为原来的额1/4 -> (56 * 56)
# 输入输出channel分别为3和96
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
# 这部分flatten操作是将我们的宽度高度展平,输出shape为(4, 3136(56*56), 96)
x = self.proj(x).flatten(2).transpose(1, 2)

在经过self.norm对应的操作为nn.LayerNorm
接着我们会经过我们的self.pos_drop(x), 这里的self.pos_dropnn.Dropout(p=drop_rate)操作。
接着进行下面各个层的操作(别忘记此时我们的输入shape为
(4, 3136(56*56), 96))

      for layer in self.layers:
            x = layer(x)

2.2 Basiclayer

接着上面我们看一下self.layers是如何构建的

# file: models/swin_transformer.py
# class: SwinTransformer
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

# file: models/swin_transformer.py
# class: BasicLayer
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

2.2 SwinTransformerBlock

# file: models/swin_transformer.py
# class: SwinTransformerBlock
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
2.2.1 W-MSA及SW-MSA输入

我们知道输入是先经过W-MSA再经过SW-MSA
经过W-MSA是没有做任何处理的即代码中shifted_x = x, 但是对于W-MSA是通过torch.roll的操作进行的,代码如下所示:

shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

这里有1和2,分别表示要左右上下移动,还有就是这里的self.shift_size为负数,说明移动完处理之后这里还是要复原的
如下代码所示:

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C # 第一个block得到(4, 56, 56, 96)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

最终得到的shape依然是我们原来的输入(4, 3136, 96)
接着下进入如下操作

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

比如一开始第一个block我们的得到的第一个输出shape为(256, 7, 7, 96) 然后我们得到第二个windows为(256, 49, 96)。相当于256windows, 每个windows49像素, 每个像素96个维度。
对于上面代码中的window_partition代码如下:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

这里的xshape为(4, 8, 7, 8, 7, 96), 我们可以得到windows的数量为
(H/windows_size) * (W/windows_size) * batch, 这里W,H一开始都为56, windows_size7, 这里设置的batch4, 因此这里我们最终得到的windows shape为(256 7 7 96)

2.2.2 Attention机制

上面的输出之后我们要经过我们的Attention机制。

attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

如果x_windowsW-MSAself.atten_maskNone, 否则会加入atten_mask, 具体代码如下(详细理解可以参考bilibili, 在31分钟左右 说的非常好):

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

对应上述代码简单点就是再不需要做内积的地方填入-100, 这样经过softmax的时候就被自动设置为0了。
下面先给出我们进入attention的代码。


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

可以看出首先会经过self.qkv生成我们的q, k, v矩阵,内部代码就是很简单的nn.Linear,

# 这里的`dim`, 我们设置为96
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# 这里我们得到的self.qkv shape 为[3, 256, 3, 49, 32] 这里的3分别对应qkv, 
# 256个窗口分别做attention, 
# 刚开始head为3, 
# 每个窗口有49个元素, 
# 32 代表每个头有32个维度
# q, k, v shape分别为[256, 3, 49, 32]
q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) 

接着用得到我们的注意力机制,如下所示,这里的self.scale可以理解为我们的v

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

最终让我们attentionposition bias相加, 如下所示获得我们最终的atten

attn = attn + relative_position_bias.unsqueeze(0)

这里的position bias 下面解释。

2.2.3 Relative Position Bias Table

我们上面说了相对位置偏置矩阵的大小为(2M-1) * (2M-1), 这里的Mwindows-size大小(详细理解可以参考bilibili, 在56分钟左右 说的非常好)。

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

下面就是之前说的经softmax, 如果mask不相同索引的我们设置为-100, 经过softmax计算就变成了0.

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

在经过

x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

操作之后我们得到的attention之后的向量为(256, 49. 96), self.proj_drop为drop_out

2.2.4 FFN(残差操作)
图中红色蓝色框的部分做了两种残差

最后要做一次残差连接

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

上述说完就完成了我们SwinTransformerBlock的部分了。

3. Patch Merging

Swin整体网络结构

通过结构图我们可以看出经过Swin Transformer Block之后会经过Patch Merging层,原理如下图所示。


对应的代码如下:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

4. 输出层

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
x = self.norm(x)  # B L C
x = self.avgpool(x.transpose(1, 2))  # B C 1
x = torch.flatten(x, 1)

经过平均池化将原来shape由(4, 49, 768)转成(4, 768, 1)后面再接一下全连接层
nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()即可。

参考
[1] Swin Transformer
[2] 如何看待swin transformer成为ICCV2021的 best paper?
[3] Swin-Transformer网络结构详解

你可能感兴趣的:(Transformer在图像中的运用(三)SwinTransformer原理及代码解读)