4、Swin Transformer:视觉Transformer的革新之路

目录

一、论文名称 

二、背景与动机

三、卖点与创新

四、具体实现细节

1、模型架构

2、Patch Partition

3、Patch Merging

4、Swin Transfomer Block

W-MSA

SW-MSA

五、对比ViT

六、一些资料


一、论文名称 

原论文地址:

         Swin Transformer: Hierarchical Vision Transformer using Shifted Windowsicon-default.png?t=N7T8https://arxiv.org/abs/2103.14030
官方开源代码地址:

https://github.com/microsoft/Swin-Transformericon-default.png?t=N7T8https://github.com/microsoft/Swin-Transformer

二、背景与动机

        在深度学习领域,卷积神经网络(CNN)在计算机视觉任务中一直占据主导地位,尤其是在图像识别、目标检测和语义分割等方面表现出色。然而,随着Transformer架构在自然语言处理领域的巨大成功,研究者们开始探索其在视觉领域的应用可能性。尽管CNN能够有效捕获局部特征,但对全局信息的建模能力相对较弱,且层级间的特征交互较为有限。

        Swin Transformer的诞生正是在这种背景下应运而生,其主要动机在于设计一个既能保留Transformer强大的全局建模能力,又能兼顾计算效率和模型参数量的视觉Transformer模型,以实现对复杂视觉场景的高效理解和表达。

        而将Transformer从语言调整到视觉的挑战来自两个领域之间的差异:

1.视觉实体的大小差异很大,NLP对象的大小是标准固定的。
2.图像中的像素与文本中的单词相比具有很高的分辨率,而CV中使用Transformer的计算复杂度是图像尺度的平方,这会导致计算量过于庞大。

三、卖点与创新

        为解决效率和全局局部信息捕捉的问题,Swin-Transformer提出了一种采用滑动窗口策略的分层Transformer架构。该模型在计算注意力时,巧妙地将注意力集中于一系列不重叠的局部窗口内,同时通过跨窗口连接机制保留了对全局上下文信息的获取能力,从而显著提升了计算效率。这种层级设计赋予了模型在多种尺度上灵活建模的能力,并且其在处理图像大小变化时展现出线性的时间复杂度特性。得益于此,Swin Transformer能够适应并有效应对各类广泛的视觉任务挑战。具体的:

  1. 移位窗口策略(Shifted Window Strategy): Swin Transformer最大的创新在于引入了“移位窗口”策略。传统的Transformer通常采用全局自注意力机制,计算复杂度随输入尺寸线性增长。而在Swin Transformer中,将图像分割为非重叠或有重叠的小窗口,注意力机制在窗口内计算,并在不同层间进行窗口内容的移位操作,既保证了对局部区域的深入挖掘,又能在不增加过多计算负担的前提下考虑更广泛的上下文信息。

  2. 层级多尺度架构(Hierarchical Multi-Stage Architecture): 受到CNN的启发,Swin Transformer构建了一种分层多阶段的结构,通过逐步增大窗口大小并在不同层次上提取特征,实现了从局部细节到全局上下文的多层次表示,有效地模拟了CNN中的空间下采样过程。

  3. 图像窗口级线性时间复杂度(Linear Complexity in Computing Attention): 通过对窗口内的自注意力计算,Swin Transformer将原本全局的自注意力复杂度降低到了与图像窗口大小相关的线性级别,大大提升了模型在大规模图像数据上的运行效率。

  4. 无位置编码(No Positional Embeddings): 在Swin Transformer中,由于使用了固定窗口的方式,位置信息已经蕴含在窗口内的相对位置中,因此无需额外添加绝对位置编码,简化了模型结构。

四、具体实现细节

1、模型架构

Swin-Transformer的基础流程如下:

  • 输入一张图片(H\times W\times 3
  • 图片经过Patch Partition层进行图片分割
  • 分割后的数据经过Linear Embedding层进行特征映射
  • 将特征映射后的数据输入Swin Transformer Block,并与Linear Embedding一起被称为第1阶段。
  • 与阶段1不同,阶段2-4在输入模型前需要进行Patch Merging进行下采样,产生分层表示。
  • 最终将经过阶段4的数据经过输出模块(包括一个LayerNorm层、一个AdaptiveAvgPool1d层和一个全连接层)进行分类。

2、Patch Partition

        Patch Partition是模型对输入图像进行预处理的一种重要操作。该操作的主要目的是将原始的连续像素图像分割成一系列固定大小的图像块(patches),以便进一步转化为Transformer可以处理的序列数据。

具体步骤如下:

  1. 划分Patch:

    输入图像首先被划分为一个个非重叠的小图像块(patches)。例如,在Swin Transformer中,通常将高和宽维度上均匀分割成大小为P × P的patch。
  2. Flatten to 1D Sequence:

    这些二维的patch会被展平成一维向量,这样每个patch就变成了一个单一的向量表示,其中包含了该patch内所有像素的信息。假设图像的通道数为C,则每个patch展平后的维度会是P × P × C
  3. Linear Embedding:

    展平后的patch序列随后会经过一个线性层(通常是全连接层),将其映射到一个新的隐藏维度空间。这个过程可以视为对patch的初步特征提取,生成的新向量维度一般记作D,即每个patch现在是一个D维的向量。

        通过Patch Partition和后续的Linear Embedding,原本结构化的图像数据转换成了Transformer可以高效处理的一系列固定长度的向量序列,为后续的自注意力机制提供输入。这种处理方式使得Transformer可以在不依赖卷积神经网络(CNN)的前提下,直接从图像patch级别的局部信息构建全局上下文理解。

import torch
import torch.nn as nn

class PatchPartition(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        """
        Initialize the Patch Partition module.
        
        Args:
            patch_size (int): The size of each patch.
            in_chans (int): Number of input channels.
            embed_dim (int): The embedding dimension.
        """
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        Apply patch partition to input tensor.
        
        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Patch embeddings tensor of shape (B, embed_dim, H/P, W/P),
                          where P is the patch size.
        """
        # Use a convolutional layer to perform both the partition and embedding
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        return x

# Example usage
B, C, H, W = 2, 3, 32, 32  # Batch size, channels, height, width
patch_size = 4
embed_dim = 96
x = torch.randn(B, C, H, W)
patch_partition = PatchPartition(patch_size=patch_size, in_chans=C, embed_dim=embed_dim)
patches = patch_partition(x)

print(f"Input shape: {x.shape}")
print(f"Patch embeddings shape: {patches.shape}")

注意:在实际操作中,Patch PartitionLinear Embedding通过一个二维的卷积层输出通道为Embedding维度卷积核大小为patch_sizestride大小为patch_size)实现

3、Patch Merging

        Patch Merging层主要是进行下采样,产生分层表示。 Patch Merging 是一种减少序列长度并增加每个补丁表示中通道数的操作。借用别人的一张图:

实际上,Patch Merging模块接收了一个输入张量 x,其中 B 是批次大小,L 是序列长度(pathch的数量,等于高度 H 乘以宽度 W),C 是通道数。补丁合并操作包括以下步骤:

  1. 将输入张量从 (B, H*W, C) 重塑为 (B, H, W, C)
  2. 对相邻的补丁进行分组,创建大小为 4*C 的新特征向量。
  3. 将这些向量通过一个归一化层,然后使用一个线性映射将通道数减半,从 4*C 到 2*C

在实际的 Swin Transformer 实现中,Patch Merging 是随着网络深度的增加而使用的操作,有助于构建不同尺度的特征层次,并最终生成用于任务的特征表示。在上述代码中,新的序列长度变为原来的四分之一(因为每四个patch合并为一个),通道数加倍。

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, in_chans):
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * in_chans, 2 * in_chans, bias=False)
        self.norm = nn.LayerNorm(4 * in_chans)

    def forward(self, x):
        """
        Apply patch merging to input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (B, H*W, C).

        Returns:
            torch.Tensor: Merged patch tensor of shape (B, H/2*W/2, 2*C).
        """
        B, L, C = x.shape
        H, W = self.input_resolution
        assert L == H * W, "Input feature has wrong size"

        # Reshape
        x = x.view(B, H, W, C)

        # Merge patches
        x0 = x[:, 0::2, 0::2, :]  # Top-left
        x1 = x[:, 1::2, 0::2, :]  # Bottom-left
        x2 = x[:, 0::2, 1::2, :]  # Top-right
        x3 = x[:, 1::2, 1::2, :]  # Bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)  # Concat along channel dimension

        # Flatten
        x = x.view(B, -1, 4 * C)

        # Normalize and reduce
        x = self.norm(x)
        x = self.reduction(x)

        return x

# Example usage
B, H, W, C = 2, 32, 32, 64  # Batch size, height, width, channels
input_resolution = (H, W)
x = torch.randn(B, H * W, C)
patch_merging = PatchMerging(input_resolution=input_resolution, in_chans=C)
merged_patches = patch_merging(x)

print(f"Input shape: {x.shape}")
print(f"Merged patch shape: {merged_patches.shape}")

4、Swin Transfomer Block

W-MSA

        W-MSA(Window-based Multi-head Self-Attention)是一种特殊的自注意力机制,用于在局部窗口中计算注意力权重。它是Swing Transformer模型中用于处理大尺寸图像的关键组件之一。

        W-MSA通过将输入的特征图划分为不重叠的局部窗口,然后在每个局部窗口内计算自注意力权重。这种窗口化的设计允许模型在处理大图像时具有较低的计算复杂度,并且对于长距离的依赖关系也能够获取有效的信息。

        W-MSA 的工作原理:

  1. 窗口划分:首先,将输入的特征图(例如,来自前一层的输出)划分为多个不重叠的小窗口。

  2. 局部注意力计算:在每个小窗口内独立地应用多头自注意力机制。这意味着,每个窗口内的像素只会计算与同一窗口内其他像素的注意力分数。这种方法有效地减少了计算量,因为它不需要计算图像中所有像素之间的注意力分数。

  3. 多头注意力:与标准的 Transformer 模型类似,W-MSA 在每个头上独立地计算注意力分数,然后将这些头的输出拼接起来。这允许模型在不同的表示子空间中捕获信息。

        W-MSA 的优势:

  • 降低计算复杂度:通过限制自注意力在小窗口内计算,W-MSA 显著降低了与图像大小成平方关系的计算复杂度。
  • 捕获局部信息:在窗口内进行注意力计算有助于捕获局部特征,这在图像处理中是很重要的。

        ⚠️:普通的MSA和W-MSA的计算量对比参考:这里

        简易代码示例:

import torch
from torch import nn

# 假设我们已经有了基本的MultiHeadSelfAttention模块(MHA)
class MultiHeadSelfAttention(nn.Module):
    # ...具体的多头自注意力实现...

class WindowBasedMSA(nn.Module):
    def __init__(self, window_size, num_heads, embed_dim, attn_drop=0., proj_drop=0.):
        super(WindowBasedMSA, self).__init__()
        self.mha = MultiHeadSelfAttention(num_heads=num_heads, embed_dim=embed_dim, dropout=attn_drop)
        self.norm = nn.LayerNorm(embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.drop = nn.Dropout(proj_drop)
        self.window_size = window_size  # 窗口大小

    def forward(self, x, H, W):  # 增加H和W作为输入参数,表示图像原始的高度和宽度
        B, N, C = x.shape  # (batch_size, sequence_length, embed_dim)

        # 计算每个维度上的窗口数量,并确保最后一个窗口可能不满的情况
        win_H, win_W = (H // self.window_size) * self.window_size, (W // self.window_size) * self.window_size
        pad_h = (H - win_H) // 2
        pad_w = (W - win_W) // 2

        # 对输入进行填充以适应窗口划分
        x = F.pad(x, (0, 0, pad_w, pad_w + (W-win_W)%self.window_size, pad_h, pad_h + (H-win_H)%self.window_size))

        # 将序列划分为窗口
        x_windows = x.view(B, H // self.window_size, self.window_size,
                           W // self.window_size, self.window_size, C).permute(0, 1, 3, 2, 4, 5).reshape(-1, self.window_size**2, C)

        # 对每个窗口应用多头自注意力
        x_windows_attended = self.mha(x_windows)

        # 恢复窗口划分前的形状
        x = x_windows_attended.reshape(B, H // self.window_size, W // self.window_size, self.window_size, self.window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H * W, C)

        # 移除填充部分
        x = x[:, pad_h:(H+pad_h)*self.window_size:window_size, pad_w:(W+pad_w)*self.window_size:window_size]

        x = self.norm(x)
        x = self.proj(x)
        x = self.drop(x)

        return x
SW-MSA

        由于W-MSA只能关注窗口本身的内容,而不允许跨窗口连接,窗口与窗口之间是无法进行信息传递的。故而引入SW-MSA通过移位窗口的方式,引入跨窗口连接的同时保持非重叠窗口的高效计算。简单步骤:

  1. 窗口循环位移:然后,将特征图的窗口沿着一个方向(通常是向右和向下)进行循环位移,位移的大小通常是窗口大小的一半。这样,原本位于窗口边缘的像素现在处于新窗口的中心位置。
  2. 执行 Shifted W-MSA:在位移后的窗口上执行自注意力计算,这样,原本属于不同窗口的像素现在可以在新的窗口中相互作用。
  3. 逆向循环位移:注意力计算完成后,将位移的窗口再次逆向位移回原来的位置,恢复原始的特征图布局。

        SW-MSA 的优势:

  • 增强跨窗口连接:SW-MSA 通过在不同层之间交替使用 W-MSA 和 SW-MSA,允许来自相邻窗口的像素进行交互,提高了模型的表达能力。
  • 无需增加计算复杂度:SW-MSA 通过循环位移而不是增加额外的注意力计算来实现跨窗口的连接,从而保持了计算效率。

        实现细节:

        如下图左所示为第 l层使用W-MSA的方式,而在下一层为 SW-MSA的方式(如右图所示),再联合Norm等层合在一起作为一个 Swin Transformer Block模块。两幅图进行对比可以发现:右图相对于左图进行了偏移,长宽分别偏移了\frac{M}{2} 个像素单位(每个窗口为 M\times M 像素)。

        可以看出,偏移后的图像窗口由4个变为了9个。为了提高计算的效率,作者提出了一种更有效的计算方法,即向左上方向循环移位,如下图所示。在此转换之后,一些窗口可能由特征映射中不相邻的几个子窗口组成(这些不连续的部分是不应该参与注意力计算的),因此采用mask机制(NLP中的masking 屏蔽不应该需要的信息)将注意力计算计算限制在每个子窗口内。

        在实际代码里,是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。先把上面的一行移到最下面,再把最左边的一列移动到最右边。借用别人的一张图:

4、Swin Transformer:视觉Transformer的革新之路_第1张图片

        代码示例:      

import torch
import torch.nn as nn
import torch.nn.functional as F

class ShiftedWindowMSA(nn.Module):
    def __init__(self, window_size, num_heads, embed_dim, shift_size=0):
        super(ShiftedWindowMSA, self).__init__()
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # 初始化attention mask
        self.attn_mask = None
        self._calculate_window_mask()

    def _calculate_window_mask(self):
        h, w = self.window_size, self.window_size
        attn_mask = torch.zeros((1, h, w, h, w))
        for i in range(h):
            for j in range(w):
                attn_mask[0, i, j, max(0, i - self.shift_size): min(h, i + self.shift_size + 1),
                            max(0, j - self.shift_size): min(w, j + self.shift_size + 1)] = 1.0
        attn_mask = attn_mask.unsqueeze(1) - torch.eye(h * w).unsqueeze(1).unsqueeze(3).repeat(1, 1, h * w, 1)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float('-inf')).masked_fill(attn_mask == 0, float(0.0))
        self.register_buffer('attn_mask', attn_mask)

    def _window_partition(self, x, window_size):
        B, H, W, C = x.shape
        win_height, win_width = window_size, window_size
        pad_h = (H % win_height) // 2
        pad_w = (W % win_width) // 2
        x = F.pad(x, (0, 0, pad_w, pad_w, pad_h, pad_h))
        x = x.view(B, H // win_height, win_height,
                   W // win_width, win_width, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_height * win_width, C)
        return windows

    def forward(self, x):
        windows = self._window_partition(x, self.window_size)

        qkv = self.qkv_proj(windows).reshape(-1, self.window_size * self.window_size, self.num_heads, self.embed_dim // self.num_heads * 3).permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.embed_dim // self.num_heads)
        # 应用attention mask
        attn = attn + self.attn_mask.unsqueeze(0)
        attn = nn.functional.softmax(attn, dim=-1)
        attn_windows = attn @ v

        attn_windows = attn_windows.reshape(-1, self.num_heads, self.window_size * self.window_size, self.embed_dim // self.num_heads).permute(0, 2, 1, 3).contiguous()
        attn_windows = attn_windows.view(-1, x.shape[1] * x.shape[2], self.embed_dim)

        x = self.out_proj(attn_windows)

        return x

上述代码中,_calculate_window_mask函数用于生成一个 attention mask,该mask将阻止每个窗口内特征与其他窗口内的特征进行注意力计算。在forward函数中,我们对attention scores加上这个mask,并在softmax操作之前应用它。

五、对比ViT

        采样方式:Swin-Transformer开始采用4倍下采样的方式,后续采用8倍下采样,最终采用16倍下采样;ViT则一开始就使用16倍下采样
        目标检测机制:Swin-Transformer中,通过4倍、8倍、16倍下采样的结果分别作为目标检测所用数据,可以使网络以不同感受野训练目标检测任务,实现对大目标、小目标的检测;ViT则只使用16倍下采样,只有单一分辨率特征

        注意力复杂度:Swin Transformer通过合并更深层的图像块(以灰色显示)来构建分层特征图,并且由于只在每个局部窗口(以红色显示)内计算注意力,因此对于输入图像大小具有线性计算复杂度。因此,它可以作为图像分类和密集识别任务的通用backbone;相比之下,以前的ViT产生单一低分辨率的特征图,并且由于计算全局的自我注意,对于输入图像大小具有二次计算复杂度。

        Swin Transformer的基本模块是基于移动窗口的自注意力层(Swin Transformer Block)。每个块包含两个子层:一个多头自注意力层,仅在当前窗口内计算自注意力;一个 MLP 层用于进一步提升特征表示。在每一阶段,窗口大小会逐渐扩大,并利用移位操作使得相邻窗口的特征能够交互。

        整个模型由连续的几个 stages 组成,每个 stage 包含多个相同的 Swin Transformer Block,并在stage之间进行patch合并以减少分辨率,从而获取更抽象的特征表示。

        总结来说,Swin Transformer巧妙地融合了CNN和Transformer的优点,突破了传统Transformer在视觉任务中的局限性,为视觉领域的研究开辟了新的方向,展现出卓越的性能和广泛的应用前景。

六、一些资料

Swin-Transformer详解_swin transformer-CSDN博客文章浏览阅读2.4k次,点赞5次,收藏18次。Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得的荣誉称号。虽然在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长,其结构不适合作为密集视觉任务或高分辨率输入图像的通过骨干网路。为了最佳的精度和速度的权衡,提出了Swin-Transformer结构。_swin transformerhttps://blog.csdn.net/qq_36758270/article/details/130833560Swin Transformer详解-CSDN博客文章浏览阅读2.5w次,点赞27次,收藏136次。之前transformer主要用于NLP领域,现在也应用到了CV领域。_swin transformerhttps://blog.csdn.net/qq_43349542/article/details/118585880

你可能感兴趣的:(AIGC论文笔记,深度学习,深度学习,人工智能)