第6周学习笔记:Vision Transformer & Swin Transformer学习

Vision Transformer模型详解

该模型将Transformer结构直接应用到图像上,即将一张图像分割成多个patches,这些patches看作是NLP的tokens (words),然后对每个patches做一系列linear embedding操作之后作为Transformer的input。

Vision Transformer 模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder(Transformer 层)
  • MLP Head(最终用于分类的层)
    第6周学习笔记:Vision Transformer & Swin Transformer学习_第1张图片

Linear Projection of Flattened Patches(Embedding层)

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。
第6周学习笔记:Vision Transformer & Swin Transformer学习_第2张图片

左边是一个个被切分好的图片块,假设原始输入的图片数据是 H x W x C,我们需要对图片进行块切割,假设图片块大小为S1 x S2,则最终的块数量N为:N = ( H / S1) * (W / S2)。然后将切分好的图片块展平为一维,那么每一个向量的长度为:Patch_dim = S1 * S2 * C,从而得到了一个N x Patch_dim的输入序列。

Transformer Encoder(Transformer 层)

Vision Transformer Encoder 有层归一化,多头注意力机制,残差连接和线性变换这四个操作

  • 给定输入编码矩阵 ,首先将其进行层归一化得到 ;
  • 利用矩阵 对 进行线性变换得到矩阵,再将矩阵输入到 Multi-Head Attention中得到矩阵 ,将最原始的输入矩阵 与 进行残差计算得到 ;
  • 将 进行第二次层归一化得到 ,然后再将 输入到全连接神经网络中进行线性变换得到 。最后将 与 进行残差操作得到该 Block 的输出;。一个 Encoder 可以将 个 Block 进行堆叠。
    第6周学习笔记:Vision Transformer & Swin Transformer学习_第3张图片
    其中Multi-Head Attention就是让模型学习全方位、多层次、多角度的信息,学习更丰富的信息特征,对于同一张图片来说,每个人看到的、注意到的部分都会存在一定差异,而在图像中的多头恰恰是把这些差异综合起来进行学习。

MLP Head(最终用于分类的层)

结束了Transformer Encoder,就到了最终的分类处理部分,在之前进行Encoder的时候通过concat的方式多加了一个用于分类的可学习向量,这时把这个向量取出来输入到MLP Head中,即经过Layer Normal --> 全连接 --> GELU --> 全连接,得到了最终的输出。
第6周学习笔记:Vision Transformer & Swin Transformer学习_第4张图片

Swin Transformer 网络详解

目前Transformer应用到图像领域主要有两大挑战:

  • 视觉实体变化大,在不同场景下视觉Transformer性能未必很好
  • 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较大

针对上述两个问题,提出了一种包含滑窗操作,具有层级设计的Swin Transformer。
其中滑窗操作包括不重叠的local window,和重叠的cross-window。将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量。
Swin Transformer的最大贡献是提出了一个可以广泛应用到所有计算机视觉领域的backbone,并且大多数在CNN网络中常见的超参数在Swin Transformer中也是可以人工调整的,例如可以调整的网络块数,每一块的层数,输入图像的大小等等。

网络整体架构

通过与CNN相似的分层结构来处理图片,使得模型能够灵活处理不同尺度的图片
第6周学习笔记:Vision Transformer & Swin Transformer学习_第5张图片
接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图,可以看出整个框架的基本流程如下:
第6周学习笔记:Vision Transformer & Swin Transformer学习_第6张图片
整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding。
  • 在每个Stage里,由Patch Merging和多个Block组成。
  • 其中Patch Merging模块主要在每个Stage一开始降低图片分辨率。
  • Block具体结构主要是LayerNorm,MLP,Window Attention 和 Shifted Window Attention组成
class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            
        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

Patch Embedding详解

在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。
具体做法是对原始图片裁成一个个 patch_size * patch_size的窗口大小,然后进行嵌入。
这里可以通过二维卷积层,将stride,kernelsize设置为patch_size大小。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度

Patch Merging详解

输入图像之后是一个Patch Partition,再之后是一个Linear Embedding层,这两个加在一起其实就是一个Patch Merging层。
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

SW-MSA详解

采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L层),那么根据之前介绍的W-MSA和SW-MSA是成对使用的,那么第L+1层使用的就是SW-MSA(右侧图)。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了M/2(下取整)个像素)。看下偏移后的窗口(右侧图),比如对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。
第6周学习笔记:Vision Transformer & Swin Transformer学习_第7张图片

你可能感兴趣的:(深度学习pytorch基础,transformer,学习,深度学习)