Swin Transformer 结构&代码解析学习

Swin Transformer 结构&代码解析学习_第1张图片


目录

  • 前言摘要
  • 一、网络总体结构及代码框架
  • 二、各部分方法&代码解析
    • 1. Patch_Embed
    • 2. Patch Merging
    • 3. Swin Transformer Block
      • 3.1 Window Multi-head Self Attention (W-MSA)
      • 3.2 Shifted Window Multi-head Self Attention (SW-MSA)
      • 3.3 Relative Position Bias*
  • 总结


前言摘要

  文章提出了一种新的ViT(Vision Transformer)作为计算机视觉任务的通用主干。而为了解决图像与NLP在数据规模和分辨率上存在的差异,设计了一种类似于ResNet等传统卷积网络类似的分层(Stage)结构,对于不同尺度的目标更具灵活性。同时引入滑窗(Shifted Window)来进行非重叠局部窗口的自注意力计算;滑窗也允许跨窗口patch的连接,在降低计算量的同时实现不同窗口区域内容的交互。

论文名称: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
论文地址: ICCV 2021 open access
代码地址: https://github.com/microsoft/Swin-Transformer

代码源自@太阳花的小绿豆,特别感谢!
导师代码地址: GitHub
导师代码讲解: https://www.bilibili.com/video/BV1yg411K7Yc


一、网络总体结构及代码框架

  网络的总体结构如Fig.1所示,首先通过tokenization方法(由Patch ParititionLinear Embedding组成)将输入图像生成token;而后通过四个不同的stage来构建尺度不同的特征图针对下游任务,每个stage中包含W-MSA模块,将特征图划分成了多个不相交的窗体(Window),且MSA注意力交互只在每个窗体(Window)内进行。相对于ViT对全局进行Multi-Head Self-Attention能够减少计算量,尤其是在浅层特征图分辨率很大的时候。然而W-MSA阻碍不同窗口之间的信息传递,所以文章也提出了SW-MSA模块,通过此方法能够实现跨窗口的信息交互;同时在不同stage间作者提出了Patch Mergring下采样方法实现对token的下采样。
Swin Transformer 结构&代码解析学习_第2张图片

Fig.1 Swin Transformer网络结构
  • Patch_Embed:三通道彩色图像在输入网络前需要token化。类似于ViT,对于第一个stage前的Patch Paritition和Linear Embedding采用Patch_Embed方法统一实现,具体是通过一个卷积层并展平(flatten)完成。
  • Patch_Mergring:由于网络是类似于深度卷积的层次化stage结构,在每一个stage输出后需要对特征进行下采样,尺度减半通道数翻倍;步骤是对token重构的特征进行隔行采样并拼接得到4块patch,而后通过LN层和线性映射压缩通道数。
  • Swin Transformer Block:滑窗注意力W-MSA及SW-MSA构成位置,这两个结构是先后成对使用的。基本结构和组成类似于ViT的transformer Block,包括MSA后的FNN当中的MLP和LN等;同时针对相应问题引入相对位置偏差和窗体分块掩码。

Swin Transformer 结构&代码解析学习_第3张图片

Fig.2 Swin Transformer模型代码调用结构

模型实现代码结构Fig.2所示,下文会对具体模块和代码部分进行分析解读。


二、各部分方法&代码解析

1. Patch_Embed

  文章结构图当中的Patch ParititionLinear Embedding部分实际由PatchEmbed类来实现。初始化类时对patch_size尺寸、输入图像通道数in_channels和embed_dim维度进行定义。通过Conv卷积操作来实现嵌入。卷积核大小和步长都是patch_size,卷积核个数为embed_dim。输入维度为(B,C,H,W),输出维度为(B,H'*W',embed_dim)。代码如下:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size) 
        self.patch_size = patch_size # 每块patch的尺寸,也是卷积核尺寸和步长
        self.in_chans = in_c # 输入图像的维度,通道数
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2) # 展平并交换维度
        x = self.norm(x) 
        return x, H, W

2. Patch Merging

Swin Transformer 结构&代码解析学习_第4张图片

Fig.3 Patch Merging实现步骤

  除第一stage以外,每个stage阶段在Swin Transformer Block前都需进行Patch Merging操作,主要目的是完成下采样生成不同尺度的特征,具体步骤如Fig.3所示。对于输入维度为(B,H*W,C)的特征,实现方法是利用间隔采样生成四组宽高较输入减半的patch,而后对四组patch进行concatenate拼接,再通过Layer_Norm和线性映射层得到的输出维度为(B,H/2*W/2,2C)。代码如下:

class PatchMerging(nn.Module):
    r""" 
    Patch Merging Layer.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        
        x = x.view(B, H, W, C)  # 对应步骤1,对输入tensor改变形状,还原为(B, H, W, C)
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
		
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        #上述对应步骤2,隔行采样生成四组patch,分别为x0,x1,x2,x3
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        # 对应步骤3,对四组patch进行拼接
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
		
        x = self.norm(x) # 对应步骤4,进行LN层计算
        x = self.reduction(x)  # [B, H/2*W/2, 2*C] #对应步骤5,线性映射改变通道数

        return x

3. Swin Transformer Block

  Swin Transformer Block相较于基本ViT,除了常规的MLP多层感知机、LN层及残差连接方法外,主要引入了W-MSA和SW-MSA,基本思想是只在固定窗体区域内对元素进行注意力计算,同时为了使得全局信息能够有效交互,采取滑动窗口注意力计算SW-MSA。同时为了解决滑动窗口引入的问题还引入了相对位置偏置滑动窗口分块掩码

3.1 Window Multi-head Self Attention (W-MSA)

Swin Transformer 结构&代码解析学习_第5张图片

3.2 Shifted Window Multi-head Self Attention (SW-MSA)

Swin Transformer 结构&代码解析学习_第6张图片
Swin Transformer 结构&代码解析学习_第7张图片

Fig.4 窗体滑动方和掩码mask的引入

  SW-MSA为了使得不同Window区域内部的元素进行注意力交互,为窗口Window添置了一个偏移量。具体的窗口滑动方法如Fig.4所示。为了解决窗口滑动后跨图像区域-如图4中A和B区域交互的不合理情况,设置和mask掩码以消除交互结果,只让在原始 图像上真实属于同一区域的特征进行交互。具体做法是创建mask,源自不同区域的位置索引值赋值-100,在计算时通过softmax最终趋近于0忽略不计。创建掩码的过程如下:

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        # 通过切片划分出来源不同patch的元素
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
		# 通过上述,对来源于不同区域的元素分别赋予不同的编码
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        # 按照window_size划分为窗口格式
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  
        # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] = [nW, Mh*Mw, Mh*Mw]
        # 来源相同区域的元素位置索引值为0,不同区域的为除0以外的其他值
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # 对为0的位置赋值0,非0位置赋值-100
        return attn_mask

  而后则构建Swin Transformer Block类并实现W-MSA和SW-MSA,结构如Fig.5所示。根据传入参数shift_size来判断MSA种类。构建MSA后完成:

窗口化→W-MSA/SW-MSA计算→窗口复原特征→FFN+残差连接

Swin Transformer 结构&代码解析学习_第8张图片

Fig.5 W-MSA和SW-MSA结构图

进而继续完成FFN和残差连接的计算。至此整个Block构建完成。Swin Transformer Block代码如下:


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        # 把feature map给pad到window size的整数倍
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            # 把前面pad的数据移除掉
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

3.3 Relative Position Bias*

Fig.6 相对位置偏置的计算过程
  在每个Head进行自注意力计算时,遵循Relative Position Bias方法在计算相似性时包括相对位置偏差项B,首先构建二维的相对位置索引,根据索引去查找对应表中的相对位置偏置值,而table表为Parameter类,是参与训练迭代过程进行学习的参数。最终将表内值映射至索引位置,完成相对位置偏置的计算。如**Fig.6**所示。而添加相对偏置的效果是最好的,如**Table.1**所示。
Table.6 位置偏置对比实验结果

Swin Transformer 结构&代码解析学习_第9张图片
而索引和表的定义初始化代码如下所示:

   self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords = torch.stack(torch.meshgrid([coords_h, coords_w],))
        # [2, Mh, Mw] 对生成的两个tensor进行拼接
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw] 展平处理
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw] 相减求相对位置编码矩阵
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)  # 将参数放置于模型当中
        # 以上代码生成relative_position_index

总结

作为ICCV 2021 Best Paper,本文主要将分层stage结构引入了ViT领域当中,引入窗体自注意力在极大降低计算量的同时利用SW-MSA实现跨窗特征的交互,有效利用全局信息。在代码方面,分区mask掩码和相对位置偏置的部分十分惊艳(微软炫技了),但理解难度也较大。开创了分层ViT和滑窗注意力的领域。

你可能感兴趣的:(深度学习_充电,transformer,深度学习,attention,计算机视觉,人工智能)