Swin-Transformer

Swin-Transformer

文章目录

  • Swin-Transformer
    • ViT
    • Swim Transformer
      • Patches & Windows
      • patch merging layers
      • 窗口自注意力
      • Architecture Variants
    • 参考

ViT

ViT使用纯Transformer结构来做图像分类任务,它开创了Transformer能够在CV领域有效工作的先河。ViT验证了在大规模数据集上进行预训练,然后迁移到小规模数据集上,Transformer性能要比CNN好。由于缺少CNN自带的归纳偏置(平移不变形和局部性),ViT在ImageNet数据集(中型数据集)上表现没有CNN好,Transformer需要充足的图像数据学习。

我们以ViT的base模型为例来描述ViT的流程。Transformer结构不能直接处理图像,首先需要将2D的图像分块(patch),CV中的patch可以近似看做NLP中的token,每块的大小为 P ∗ P ∗ C P*P*C PPC。假设一个大小为 224 ∗ 224 ∗ 3 224*224*3 2242243的图像,每块的大小为 16 ∗ 16 ∗ 3 16*16*3 16163,那么此张图片将有 224 16 ∗ 224 16 = 14 ∗ 14 = 196 \frac{224}{16}*\frac{224}{16}=14*14=196 1622416224=1414=196个块。图像预处理将一个2D的 224 ∗ 224 ∗ 3 224*224*3 2242243的图像展平为1D的 196 ∗ 768 196*768 196768大小的向量。接下来,进行图像块嵌入(类似于NLP中的词嵌入),就是ViT论文中的 E E E E E E的维度是 768 ∗ 768 768*768 768768。映射后的向量维度仍然为 196 ∗ 768 196*768 196768。类似于BERT中的[class] token,ViT中加入了一个可以学习的嵌入,如下图中的第0位置,它经过Transformer 编码器后的输出作为图像表示 y y y,用于分类。就这样,嵌入向量就由 196 ∗ 768 196*768 196768变为 197 ∗ 768 197*768 197768。为了保持输入图像块之间的空间位置信息,对映射后的向量添加了一个位置编码信息,如下图一中的0-9数字。位置编码采用的是1-D的可学习嵌入变量,论文中实验验证2-D的位置编码和1-D的位置编码结果近似。

Swin-Transformer_第1张图片 图一:ViT示意图

Swim Transformer

Swim Transformer是特为视觉领域设计的一种分层Transformer结构。Swin 的两大特性是滑动窗口和分层表示。滑动窗口在局部不重叠的窗口中计算自注意力,并允许跨窗口连接。分层结构允许模型适配不同尺度的图片,并且计算复杂度与图像大小呈线性关系。

ViT只能够做分类,Swin Transformer借鉴了CNN的分层结构,如下图二(a),不仅能够做分类,还能够和CNN一样扩展到下游任务,比如检测,分割等。Swim Transformer不同于标准的Transformer结构,它计算不重叠窗口中的自注意力。为了解决窗口和窗口之间无连接的问题,Swin提出了移位窗口分割方法,见下图二(b),W-MSA和SW-MSA在连续的Swin Transformer blocks中交替出现,见下图二©。因此不论哪个Swim Transformer版本,都有偶数个blocks。

下图二(d)展示了Swin Transformer的tiny版本(Swin-T)。首先,它通过一个patch分割模块将输入的RGB图像分割成不重叠的patches,每个patch被看做是一个“token”,在论文中,patch size大小为 4 × 4 4 \times 4 4×4,每个patch的特征维度为 4 × 4 × 3 = 48 4 \times 4 \times 3 = 48 4×4×3=48。对于一个 H × W H \times W H×W大小的RGB图像,经过patch分割模块之后表示为 H 4 × W 4 × 48 \frac{H}{4} \times \frac{W}{4} \times 48 4H×4W×48。紧接着一个线性嵌入层将此原始值特征映射为一个任意的维度,记为 C C C。Swin Transformer block 应用到这些patch token上。线性映射加上Swin Transformer block,被称为“Stage 1”。为了得到分层表示,随着网络层数的加深,token的数量通过patch merging layers减少。第一个patch merging layer层连接每组 2 × 2 2 \times 2 2×2相邻patches的特征,然后在维度为 4 C 4C 4C的连接特征上应用线性层降维到 2 C 2C 2C。“Stage 2”,“Stage 3”和“Stage 4”由patch merging layer和Swin Transformer block组成,因此每个阶段的尺寸减少 2 2 2倍,维度增大 2 2 2倍,以至于“Stage 4”的输出特征为 H 32 × W 32 × 8 C \frac{H}{32} \times \frac{W}{32} \times 8C 32H×32W×8C
Swin-Transformer_第2张图片
图二:Swin Transformer 架构

Patches & Windows

Swin-Transformer_第3张图片
图三:patches 和windows

一张 H × W H \times W H×W大小的图中,里面包含 H × W H \times W H×W个像素。一个patch就是图像中的 N × N N \times N N×N个像素区域;一个window是由 M × M M \times M M×M个patches组成的。由上图所示,图像被分成 4 4 4个窗口,每个窗口包含 4 × 4 = 16 4 \times 4 =16 4×4=16个patches。假设每个patch的大小为 4 × 4 4 \times 4 4×4,则每个patch的向量维度为 4 × 4 × 3 = 48 4 \times 4 \times 3 = 48 4×4×3=48。每个patch可以看做NLP中的“token”,仿照NLP的词嵌入,将patch映射为维度为 C C C的向量。

下述代码展示了如何将图像如何进行patch嵌入。假设一张 224 × 224 × 3 224 \times 224 \times 3 224×224×3的图片,patch size大小为 4 × 4 4 \times 4 4×4,经过一个卷积层(第26行代码)之后的输出shape为 ( B , C , 56 , 56 ) (B, C, 56,56) (B,C,56,56),展平后两项,并对换后两项的位置,最后嵌入的输出为 ( B , 56 ∗ 56 , C ) (B,56*56,C) (B,5656,C)

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size) 
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 
        self.img_size = img_size  # 图像尺寸
        self.patch_size = patch_size # patch大小
        self.patches_resolution = patches_resolution
        # patches 数量
        self.num_patches = patches_resolution[0] * patches_resolution[1] 

        self.in_chans = in_chans   # 输入图像通道,默认3
        self.embed_dim = embed_dim  # 映射输出通道

        # 线性映射
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # Batch, embed_dim, img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim) # 正则
        else:
            self.norm = None
    #
    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # 假设 (H,W)=(224,224),那么(Ph,Pw)=(224/4=56, 224/4=56)
        # self.proj(x)输出shape为(B, C, 56, 56)    
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw=(56*56) C
        if self.norm is not None:
            x = self.norm(x)
        return x

patch merging layers

patch merging layers是Swim Transformer分层结构的重要组件。它连接每组 2 × 2 2 \times 2 2×2相邻patches的特征,然后在维度为 4 C 4C 4C的连接特征上应用线性层降维到 2 C 2C 2C。下图四展示了patch merging layers如何将一个 h × w × 1 h \times w \times 1 h×w×1的特征如何转换为 h 2 × w 2 × 4 \frac{h}{2} \times \frac{w}{2} \times 4 2h×2w×4。将$h \times w 特 征 特征 x$划分为大小为 2 × 2 2 \times 2 2×2的组,提取每组相同位置的特征得到 x 0 , x 1 , x 2 , x 3 x_{0}, x_{1}, x_{2},x_{3} x0,x1,x2,x3(下述代码第28-31行),合并 x 0 , x 1 , x 2 , x 3 x_{0}, x_{1}, x_{2},x_{3} x0,x1,x2,x3,通道数量则扩大 4 4 4倍(下述代码第32行),然后再通过线性层降维(下述代码第14和36行)。

Swin-Transformer_第4张图片
图四:patch merger sample

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) # 降维 4dim ---> 2dim
        self.norm = norm_layer(4 * dim) # 归一化层

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x) # 归一化
        x = self.reduction(x) # 降维

        return x

窗口自注意力

标准的Transformer架构是全局自注意力,它计算某个token和其他token之间的自注意力,计算复杂度和token的数量呈平方关系。视觉图像的token数量要多于语言中的单词token数量,Transformer在视觉中会耗费更多的资源,尤其对于高质量图像,计算复杂度会非常大。基于此种情况,Swim Transformer采用基于窗口的自注意替换标准的全局注意力。

将一张patches数量为 h × w h \times w h×w的图像拆分成不重叠的窗口,每个窗口包含 M × M M \times M M×M个patches。我们先回忆一下标准Transformer中的多头自注意力。假设输入为 x x x,将 x x x进行线性嵌入得到 Q , K , V Q, K, V Q,K,V三个向量, Q Q Q K K K两个向量相乘计算得到Attention,然后Attention与向量 V V V相乘之后再线性映射得到输出。假设patches的数量为 N N N,通道数为 C C C,那么两次线性计算复杂度为 4 N C 2 4NC^{2} 4NC2 Q , K , V Q, K, V Q,K,V的两次矩阵计算的复杂度为 2 N 2 C 2N^{2}C 2N2C。那么对于标准的多头自注意力,它的计算复杂度为 Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C \Omega\left( MSA \right) = 4hwC^{2} + 2\left( hw \right)^{2}C Ω(MSA)=4hwC2+2(hw)2C;一个窗口的自注意力计算复杂度为 4 M 2 C 2 + 2 M 4 C 4M^{2}C^{2} + 2M^{4}C 4M2C2+2M4C,此张图片总共有 h M × h M \frac{h}{M} \times \frac{h}{M} Mh×Mh个窗口,那么总的基于窗口的自注意力的计算复杂度为 Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \Omega\left( W-MSA \right) = 4hwC^{2} + 2M^{2} hw C Ω(WMSA)=4hwC2+2M2hwC

对于 224 × 224 224 \times 224 224×224大小的图片,每一个patches的大小为 4 × 4 4 \times 4 4×4,那么总共有 56 × 56 56 \times 56 56×56个patches。论文中默认 M = 7 M=7 M=7 Ω ( M S A ) \Omega\left( MSA \right) Ω(MSA) Ω ( W − M S A ) \Omega\left( W-MSA \right) Ω(WMSA)计算复杂度相差在矩阵相乘的部分, 2 × 56 × 56 h w C 2 \times 56 \times 56 hwC 2×56×56hwC 2 × 7 2 h w C 2 \times 7^{2} hwC 2×72hwC的近60倍。随着图片的尺寸越大,这个差距会越大。

论文在计算自注意力时引入了相对位置偏置(relative position bias),论文实验表明,相对位置偏置在ImageNet,CoCo和ADE20k数据集上的表现要优于不加偏置和使用绝对位置偏置。下述代码展示了带有相对位置偏置的窗口多头自注意力的前向过程。它支持窗口自注意力和移动窗口自注意力。窗口自注意力计算包含三个方面,常规多头自注意力,相对位置偏置的计算和移动窗口的掩码计算。常规多头自注意力有Transformer的基础就很好理解,难点在于相对位置偏置的计算和移动窗口的掩码计算。
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T / d + B ) V Attention(Q,K,V) = SoftMax(QK^{T} / \sqrt{d}+ B) V Attention(Q,K,V)=SoftMax(QKT/d +B)V

# 通道注意力计算
def forward(self, x, mask=None):
    """
    Attention(Q,K,V) = SoftMax(QK^{T}/sqrt(d) + Bias)V
    x: 输入特征  shape: (num_windows*B, N, C)
    mask: 掩码
    """
    B_, N, C = x.shape  # N=Wh*Ww 窗口里面的patches数量
    # qkv.shape: (3, num_windows*B, self.num_heads, N, C // self.num_heads)  
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2] # shape: (num_windows*B, self.num_heads, N, C // self.num_heads) 
    
    # self.scale对应于公式中的sqrt(d)
    q = q * self.scale
    attn = (q @ k.transpose(-2, -1)) # QK^{T}/sqrt(d)  atten.shape: (num_windows*B, self.num_heads, N,  N)

    # 相对位置偏置
    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH  nH是head的数量
    #relative_position_bias.shape=(nH, Wh*Ww, Wh*Ww)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
    attn = attn + relative_position_bias.unsqueeze(0)  # QK^{T}/sqrt(d) + Bias
    
    # 掩码
    if mask is not None:
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
    else:
        attn = self.softmax(attn) # SoftMax(QK^{T}/sqrt(d) + Bias)

    attn = self.attn_drop(attn)

    # x.shape=(num_windows*B, N, C)
    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # SoftMax(QK^{T}/sqrt(d) + Bias)V
    x = self.proj(x) # 映射
    x = self.proj_drop(x) # drop
    return x

绝对位置编码是在进行自注意力计算之前为每个token添加一个可学习的参数,相对位置编码,是在进行自注意力计算时,在计算过程中添加一个可学习的相对位置参数。

相对位置偏置 B ∈ R M 2 × M 2 B \in \mathbb{R}^{M^{2} \times M^{2}} BRM2×M2,每一个轴的取值范围是 [ − M + 1 , M − 1 ] [-M+1, M-1] [M+1,M1]。计算自注意力时,每个token都要与其他位置上的token计算 Q K QK QK值。对于一个大小为 2 × 2 2\times2 2×2的窗口,位置1上的patch要与位置1,2,3,4的patch计算 Q K QK QK值,位置2上的patch要与位置1,2,3,4上的patch计算 Q K QK QK值,… ,那么其他位置相对于当前位置都有一个偏移量。下图5中展示了relative_coords(下述代码第8行)其他位置相当于当前位置的偏移量(按列看),为了便于后续的计算,对每个元素都加上偏移量,使其从零开始,如下述代码第9和第10行。由于(0,1)和(1,0),(-1,0)和(0,-1)它们取和后的总偏移量结果一样,因为对某一列坐标进行乘法变换,如下述代码第11行,最后再取和得到总的偏移量relative_position_index。至此,相对位置的下标取值范围为 [ 0 , 8 ] [0,8] [0,8],可由一个 ( 2 M − 1 ) ∗ ( 2 M − 1 ) (2M-1)*(2M-1) (2M1)(2M1)大小的矩阵表示,参数化这个更小尺寸的偏置矩阵 B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat{B} \in \mathbb{R}^{\left( 2M-1\right) \times \left( 2M-1\right)} B^R(2M1)×(2M1),那么 B B B的值就可以从 B ^ \hat{B} B^中提取。

 # define a parameter table of relative position bias
 self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
 
 # get pair-wise relative position index for each token inside the window
 # 以下用Wh * Ww 替代 self.window_size[0] * self.window_size[1]
 coords_h = torch.arange(self.window_size[0]) # [0,1,...,Wh-1]
 coords_w = torch.arange(self.window_size[1]) # [0,1,...,Ww-1]
 coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
 coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
 relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
 relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
 relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
 relative_coords[:, :, 1] += self.window_size[1] - 1
 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
 relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
 # 将relative_position_index注册为一个不参与网络学习的变量。
 self.register_buffer("relative_position_index", relative_position_index)

 # 使用截断正态分布中提取的值填充输入张量。
 trunc_normal_(self.relative_position_bias_table, std=.02)

 # forward函数中相对未知的偏置
 '''self.relative_position_index是计算出不可学习的量  第17行
 self.relative_position_index.shape=(Wh*Ww, Wh*Ww) 第15行
 self.relative_position_bias_table.shape=(2*Wh-1 * 2*Ww-1, nH) 第2行
 self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
 '''
 relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH  nH是head的数量

Swin-Transformer_第5张图片

Swim 计算每个窗口中的自注意,窗口与窗口之间无计算,失去了Transformer处理全局信息的特征。为此,Swim Transformer提出了移位窗口分割方法。假设有 8 × 8 8 \times 8 8×8 patches的图片,其中一个窗口包含 4 × 4 4 \times 4 4×4个patches,那么将有 2 × 2 2 \times 2 2×2个窗口。现在将左上角的窗口向左下移位 2 2 2个patches,四个窗口将被重新划分为 9 9 9个大小不一的窗口,如图六所示,只有标号为4的窗口和原窗口大小一致。

最直接的想法是对小窗口进行padding,并在计算的时候屏蔽掉填充的值。但是,自注意力计算将由四个被扩展到九个,计算多了2.25倍。为了不增加计算量,论文中提出了循环移位(cyclic-shift)算法,如图六所示,将编号3,6的窗口移位到编号5,8的窗口下面,将编号0,1的窗口移位到编号6,7的窗口左面,将编号为0的窗口,从左上角移位到右下角。这样就可以重新拼凑出 2 × 2 2 \times 2 2×2 (4,(7,1),(3,5),(0,2,6,8))个窗口。拼凑出的窗口在原图中属于不同的位置,不相连,以标号为0,2,6,8窗口组成的大窗口为例,这四个小窗口分别位于原图的四个顶点,关联性极低,因此,在计算窗口注意力时,需要掩码机制,只能计算相同子窗口的自注意力,不同窗口的自注意力结果要为0。标号为0,2,6,8窗口,在计算窗口自注意力时,窗口0中的每一个patch分别需要和窗口0,2,6,8中的每一patch进行自注意力计算,那么窗口0中的patch与窗口0中的patch的自注意力是有用的,但是窗口0中的patch与窗口2,6,8中的patch的自注意力需要设为0。我们回忆一下Attention的计算公式, A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T / d + B ) V Attention(Q,K,V) = SoftMax(QK^{T} / \sqrt{d}+ B) V Attention(Q,K,V)=SoftMax(QKT/d +B)V,自注意力计算最后需要Softmax函数。在不同窗口的自注意力值上添加 − 100 -100 100(下图代码第20行,mask赋值-100;第27行,将mask添加到自注意力值上,然后再进行softmax计算),在softmax计算过程中, − 100 -100 100会无限趋近于0,从而达到归0的效果。

Swin-Transformer_第6张图片
图六:窗口移动

if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    # 给每个窗口上索引
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # 拆分窗口
    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # nW, window_size * window_size
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size * window_size, window_size * window_size
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) # 
else:
    attn_mask = None


if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # 加mask
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)

下图展示了循环移位之后的窗口组成和mask值的分布情况。

Swin-Transformer_第7张图片
图七:自注意力mask

Architecture Variants

Swin Transformer有四种形式,分别命名为Swin-T,Swin-S,Swin-B和Swin-L。以Swin-T为基础模型版本,Swin-T,Swin-S,和Swin-L分别是基础模型的 0.25 × 0.25\times 0.25×, 0.5 × 0.5\times 0.5× 2 × 2\times 2×倍。这四种模型的架构如图8所示。

Swin-Transformer_第8张图片
图8:architecture

ViT模型将Transformer结构应用到视觉领域,但是仍然还受限于图片的尺寸大小。Swin引入移动窗口和分层结构,使得自注意力在视觉领域的计算复杂度能与图片大小成线性关系。Swin吸取了CNN和Transformer的优点,在ImageNet-1k的数据集上也能取得SOTA效果,相比于ViT模型,降低了数据的需求量。

参考

  1. Swin-Transformer code
  2. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  3. The Question about the mask of window attention
  4. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

你可能感兴趣的:(机器学习和深度学习之旅,transformer,深度学习,机器学习,计算机视觉,人工智能)