Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

Paper: https://arxiv.org/abs/2103.14030

Date: submitted 25 Mar 2021 (v1), last revised 17 Aug 2021 (v2)

Venue: ICCV 2021, a top-tier conference (arXiv subject: Computer Vision and Pattern Recognition, cs.CV)

Code: https://github.com/microsoft/Swin-Transformer

Tasks: image classification, object detection, semantic segmentation


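0 Abstract

[Challenges] Adapting the Transformer from language to vision faces two main challenges:

Visual entities vary widely in scale, so a single fixed token scale does not suit every scene;

Images have far more pixels than text passages have words, which makes the computation expensive.

[This work] To address these issues, the paper proposes a hierarchical Transformer whose representation is computed with shifted windows.

[Key property] The shifted-window scheme restricts self-attention to non-overlapping local windows while still allowing cross-window connections, which makes it more efficient.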

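1 Introduction

Modeling in computer vision has long been dominated by CNNs.

The Transformer, by contrast, is best known for its success on natural language processing (NLP) tasks.

This paper seeks to extend the applicability of the Transformer so that it can serve as a general-purpose backbone for computer vision; see Figure 1.

Figure 1. (a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers, and has linear computational complexity in the input image size because self-attention is computed only within each local window (shown in red). It can therefore serve as a general-purpose backbone for both image classification and dense recognition tasks.

(b) In contrast, the earlier ViT produces feature maps of a single low resolution and has quadratic computational complexity in the input image size because self-attention is computed globally.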
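The linear-versus-quadratic contrast in Figure 1 can be checked with a little arithmetic. The sketch below uses the complexity estimates given in the paper for a feature map of h x w tokens with C channels and window size M: Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC. The function names are illustrative and are not part of the official code.

```python
def msa_flops(h, w, C):
    """Global multi-head self-attention: quadratic in the token count h*w."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window-based self-attention with M x M windows: linear in h*w."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


if __name__ == "__main__":
    C, M = 96, 7
    for side in (56, 112, 224):
        global_cost = msa_flops(side, side, C)
        window_cost = wmsa_flops(side, side, C, M)
        print(f"{side}x{side} tokens: MSA={global_cost:.3e}  "
              f"W-MSA={window_cost:.3e}  ratio={global_cost / window_cost:.1f}")
```

Doubling the side length multiplies the W-MSA cost by exactly 4 (the token count), while the global MSA cost grows faster because of the (hw)² term.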


2 Related Work

CNN and variants;

Self-attention based backbone architectures;

Self-attention/Transformers to complement CNNs

Transformer based vision backbones


3 Method

3.1 Architecture

Figure 3. (a) Architecture of a Swin Transformer (Swin-T); (b) two successive Swin Transformer blocks.

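W-MSA: regular Window Multi-head Self-Attention;

SW-MSA: Shifted Window Multi-head Self-Attention;

W-MSA and SW-MSA are multi-head self-attention modules with regular and shifted window configurations, respectively.

Input: $H \times W \times 3$

The image is split into non-overlapping patches of size $4 \times 4 \times 3$ (3 is the number of channels);

Stage1:

Each patch is treated as a token, giving $\frac{H}{4} \times \frac{W}{4}$ tokens whose raw feature is the $4 \times 4 \times 3 = 48$ pixel values of the patch. A linear embedding then projects each token to a C-dimensional feature vector (see PatchEmbed below). After two Swin Transformer blocks (which contain the self-attention mechanism; structure shown in Figure 3(b)), the tokens have shape $\frac{H}{4} \times \frac{W}{4} \times C$

Stage2:

PatchMerging merges patches (lowering the resolution): the features of each group of $2 \times 2$ neighbouring patches are concatenated, and a linear layer is applied to the resulting 4C-dimensional features (the number of tokens drops by a factor of $2 \times 2 = 4$, and the output dimension doubles to 2C);

Two Swin Transformer blocks; the tokens have shape $\frac{H}{8} \times \frac{W}{8} \times 2C$

Stage3:

PatchMerging merges patches;

Six Swin Transformer blocks; the tokens have shape $\frac{H}{16} \times \frac{W}{16} \times 4C$

Stage4:

PatchMerging merges patches;

Two Swin Transformer blocks; the tokens have shape $\frac{H}{32} \times \frac{W}{32} \times 8C$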
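The stage-by-stage shapes above can be traced with a few lines of arithmetic. This is a minimal sketch, assuming a 224 x 224 input and C = 96 (the Swin-T setting printed at the end of these notes); `stage_shapes` is a hypothetical helper, not a function from the official repository.

```python
def stage_shapes(H, W, C, patch_size=4, num_stages=4):
    """Return the (height, width, channels) of the token map after each stage."""
    shapes = []
    # After patch embedding (start of Stage 1).
    h, w, c = H // patch_size, W // patch_size, C
    for stage in range(num_stages):
        if stage > 0:
            # PatchMerging: halve the resolution, double the channels.
            h, w, c = h // 2, w // 2, c * 2
        shapes.append((h, w, c))
    return shapes


if __name__ == "__main__":
    for i, shape in enumerate(stage_shapes(224, 224, 96), start=1):
        print(f"Stage {i}: {shape}")
```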

Main architecture:

import torch
import torch.nn as nn


class SwinTransformer(nn.Module):
    def __init__(self, ..., depths=[2, 2, 6, 2]):
        super().__init__()

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(...)

        # absolute position embedding
        # ...

        # stochastic depth
        # ...

        # build layers
        self.layers = nn.ModuleList()
        # num_layers: 4 stages
        for i_layer in range(self.num_layers):
            # each BasicLayer holds its Swin blocks plus, where present,
            # the downsample step (PatchMerging)
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def forward_features(self, x):
        x = self.patch_embed(x)
        # absolute position embedding
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        # 4 stages
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

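Swin Transformer block

In the figure below, left: Swin Transformer block; right: standard Transformer block.

A Swin Transformer block replaces the standard multi-head self-attention (MSA) module of a Transformer block with a window-based module (W-MSA or SW-MSA, used in alternating blocks); the other layers are kept the same.

[Figure: Swin Transformer block (left) and standard Transformer block (right)]

Code skeleton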

class SwinTransformerBlock(nn.Module):
    def __init__(self, ...):
        super().__init__()

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        # first LayerNorm
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # ... cyclic shift and window partition produce x_windows (code shown in 3.2.3)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # ... window_reverse and reverse cyclic shift produce x (code shown in 3.2.3)

        x = x.view(B, H * W, C)
        # residual connection
        x = shortcut + self.drop_path(x)

        # LayerNorm, MLP, residual connection
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
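Which blocks use W-MSA and which use SW-MSA follows a simple schedule. The sketch below reproduces the `window_size` / `shift_size` values visible in the Swin-T printout at the end of these notes (0 and 3 alternating, and 0 for both blocks of the 7 x 7 stage). `block_shift_sizes` is a hypothetical helper, and the rule for small feature maps is an assumption inferred from that printout.

```python
def block_shift_sizes(input_resolution, depth, window_size=7):
    """Return (window_size, shift_size) for each block in one stage."""
    # Assumption: if the feature map is no larger than the window, use a single
    # window covering the whole map and disable shifting.
    if min(input_resolution) <= window_size:
        return [(min(input_resolution), 0)] * depth
    # Even-indexed blocks use regular windows (W-MSA); odd-indexed blocks use
    # windows shifted by half the window size (SW-MSA).
    return [(window_size, 0 if i % 2 == 0 else window_size // 2)
            for i in range(depth)]


if __name__ == "__main__":
    stages = [((56, 56), 2), ((28, 28), 2), ((14, 14), 6), ((7, 7), 2)]
    for resolution, depth in stages:
        print(resolution, block_shift_sizes(resolution, depth))
```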

3.2. Shifted Window based Self-Attention

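3.2.1 Regular window

Window attention is computed as $\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}(QK^T / \sqrt{d} + B)\,V$, where B is the relative position bias; the attention computation is restricted to the tokens inside each window;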

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # trunc_normal_ is imported from timm in the official repository
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
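The `relative_position_index` buffer built in `__init__` is easier to follow on a tiny window. The following is a plain-Python sketch of the same index computation (no tensors); the helper name is hypothetical and the 2 x 2 window is chosen only to keep the table small.

```python
def relative_position_index(window_h, window_w):
    """Index into the (2*Wh-1)*(2*Ww-1) bias table for every pair of tokens."""
    # Token coordinates in row-major order, matching torch.flatten(coords, 1).
    coords = [(r, c) for r in range(window_h) for c in range(window_w)]
    index = []
    for (ri, ci) in coords:  # query token
        row = []
        for (rj, cj) in coords:  # key token
            dr = ri - rj + window_h - 1  # shift row offset to start from 0
            dc = ci - cj + window_w - 1  # shift column offset to start from 0
            # Flatten the 2D offset into a single table index.
            row.append(dr * (2 * window_w - 1) + dc)
        index.append(row)
    return index


if __name__ == "__main__":
    for row in relative_position_index(2, 2):
        print(row)
```

Pairs of tokens with the same relative offset share an index, so they share one learned bias per head; for a 2 x 2 window the table has (2*2-1)*(2*2-1) = 9 entries.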

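3.2.2 Shifted window

This is the transformation shown in Figure 2; it is best understood together with 3.2.3.

3.2.3 Efficient batch computation for the shifted configuration (cyclic shift)

The paper proposes a more efficient batch computation approach that cyclically shifts the feature map toward the top-left, as illustrated in Figure 4.

Figure 2. Illustration of the shifted window approach for computing self-attention in Swin Transformer.

Left: regular windows; right: shifted windows (the shift increases the number of windows to 9).

In the actual code this is implemented indirectly: the feature map is cyclically shifted (with torch.roll) and a mask is applied to the attention. The number of windows stays the same, and the final result is equivalent.

In other words: the regions pointed to by the orange arrows in the figure below are moved, by cyclically shifting the feature map, to the positions pointed to by the green arrows; the map is then divided into the 4 red windows, and the mask makes the computation correspond to the 9 shifted windows of Figure 2;

The map is then shifted back;

Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.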
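The masking idea can be sketched without tensors. After the cyclic shift, some windows contain tokens that were not neighbours in the original feature map, so each position is labelled with a region id and attention is allowed only between tokens with the same id. This is a minimal plain-Python sketch under stated assumptions: `build_shift_mask` is a hypothetical helper, and `-100.0` is an assumed large negative value standing in for the "-inf" entries of the (0/-inf) mask described in the WindowAttention docstring.

```python
NEG = -100.0  # assumption: a large negative value standing in for -inf


def build_shift_mask(H, W, window_size, shift_size):
    """Return one (N x N) additive mask per window, N = window_size ** 2."""

    def band(i, size):
        # After the cyclic shift each axis falls into three bands: positions
        # outside the last window, the last window minus the wrapped part,
        # and the wrapped-around part that came from the opposite border.
        if i < size - window_size:
            return 0
        if i < size - shift_size:
            return 1
        return 2

    # Region id of every position in the shifted feature map (at most 9 ids).
    region = [[band(r, H) * 3 + band(c, W) for c in range(W)] for r in range(H)]

    masks = []
    for top in range(0, H, window_size):
        for left in range(0, W, window_size):
            ids = [region[r][c]
                   for r in range(top, top + window_size)
                   for c in range(left, left + window_size)]
            # Tokens may attend to each other only if they share a region id.
            masks.append([[0.0 if a == b else NEG for b in ids] for a in ids])
    return masks


if __name__ == "__main__":
    masks = build_shift_mask(H=8, W=8, window_size=4, shift_size=2)
    for k, mask in enumerate(masks):
        blocked = sum(value != 0.0 for row in mask for value in row)
        print(f"window {k}: {blocked} masked pairs out of {len(mask) ** 2}")
```

With an 8 x 8 map, window size 4 and shift 2, the top-left window needs no masking, while the bottom-right window mixes four regions and has the most masked pairs.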

For example, the cyclic shift and its reverse in the block's forward pass:

# cyclic shift
if self.shift_size > 0:
    if not self.fused_window_process:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    else:
        x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
    shifted_x = x
    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C


# reverse cyclic shift
if self.shift_size > 0:
    if not self.fused_window_process:
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
    x = shifted_x
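What `torch.roll` does to the feature map can be seen on a small grid. The sketch below is a plain-Python stand-in for rolling over two dimensions (`roll2d` is a hypothetical helper, not a PyTorch function): a negative shift moves content toward the top-left with wrap-around, and the opposite shift restores the original layout.

```python
def roll2d(grid, shift_rows, shift_cols):
    """Cyclically shift a 2D list; positive shifts move content down/right."""
    H, W = len(grid), len(grid[0])
    return [[grid[(r - shift_rows) % H][(c - shift_cols) % W] for c in range(W)]
            for r in range(H)]


if __name__ == "__main__":
    grid = [[r * 4 + c for c in range(4)] for r in range(4)]
    shifted = roll2d(grid, -1, -1)    # cyclic shift toward the top-left
    restored = roll2d(shifted, 1, 1)  # reverse cyclic shift
    for row in shifted:
        print(row)
    print("restored == original:", restored == grid)
```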

Window partition and its inverse:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
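The view/permute logic of the two functions above amounts to cutting the map into tiles and pasting them back. The sketch below shows the same idea on a single 2D grid with plain lists; `partition_grid` and `merge_windows` are hypothetical names chosen to avoid confusion with the tensor versions.

```python
def partition_grid(grid, window_size):
    """Split an H x W grid (list of rows) into non-overlapping square windows."""
    H, W = len(grid), len(grid[0])
    windows = []
    for top in range(0, H, window_size):
        for left in range(0, W, window_size):
            windows.append([row[left:left + window_size]
                            for row in grid[top:top + window_size]])
    return windows


def merge_windows(windows, window_size, H, W):
    """Inverse of partition_grid: paste the windows back into an H x W grid."""
    grid = [[None] * W for _ in range(H)]
    windows_per_row = W // window_size
    for k, window in enumerate(windows):
        top = (k // windows_per_row) * window_size
        left = (k % windows_per_row) * window_size
        for r in range(window_size):
            for c in range(window_size):
                grid[top + r][left + c] = window[r][c]
    return grid


if __name__ == "__main__":
    grid = [[r * 4 + c for c in range(4)] for r in range(4)]
    windows = partition_grid(grid, 2)
    print(windows)
    print("round trip ok:", merge_windows(windows, 2, 4, 4) == grid)
```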

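3.2.4 Relative position bias

In the attention formula, a relative position bias $B \in \mathbb{R}^{M^2 \times M^2}$ is added to the similarity scores of each head, where:

M: window size ($M^2$ is the number of patches in a window);

Its effect is shown in the table below:

[Table image: effect of the relative position bias]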

3.3. Architecture Variants

Swin-T: C = 96, layer numbers = {2, 2, 6, 2}

Swin-S: C = 96, layer numbers = {2, 2, 18, 2}

Swin-B: C = 128, layer numbers = {2, 2, 18, 2}

Swin-L: C = 192, layer numbers = {2, 2, 18, 2}
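The four variants can be written as a small configuration table from which the per-stage widths follow. This is a sketch, not the official config format: the dictionary layout and helper names are hypothetical, and the head dimension of 32 per head is taken from the paper's default setting (it matches the `num_heads` values 3, 6, 12, 24 in the Swin-T printout).

```python
VARIANTS = {
    "Swin-T": {"C": 96, "depths": (2, 2, 6, 2)},
    "Swin-S": {"C": 96, "depths": (2, 2, 18, 2)},
    "Swin-B": {"C": 128, "depths": (2, 2, 18, 2)},
    "Swin-L": {"C": 192, "depths": (2, 2, 18, 2)},
}


def stage_channels(C, num_stages=4):
    """Channel width of each stage: C, 2C, 4C, 8C."""
    return [C * 2**i for i in range(num_stages)]


def stage_heads(C, head_dim=32, num_stages=4):
    """Number of attention heads per stage, assuming a fixed head dimension."""
    return [channels // head_dim for channels in stage_channels(C, num_stages)]


if __name__ == "__main__":
    for name, cfg in VARIANTS.items():
        print(name, "blocks:", sum(cfg["depths"]),
              "channels:", stage_channels(cfg["C"]),
              "heads:", stage_heads(cfg["C"]))
```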

4. Experiments

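Experiments are conducted on ImageNet-1K image classification, COCO object detection, and ADE20K semantic segmentation;

Additional code notes

PatchEmbed

A 2D convolution with both stride and kernel_size set to patch_size splits the image into patches and embeds them in one step; the number of output channels sets the size of the embedding vector. The H and W dimensions are then flattened into one and moved to dimension 1, giving a tensor of shape (B, Ph*Pw, C);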

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        # to_2tuple is imported from timm in the official repository
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # kernel_size == stride == patch_size: one output position per patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
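A convolution whose kernel size equals its stride touches each patch exactly once, so it acts as one shared linear map applied to every flattened patch. The sketch below shows only the tokenization half of that (cutting the image into patches and flattening each one), not the learned projection; `extract_patches` is a hypothetical helper and the 8 x 8 image is chosen for brevity.

```python
def extract_patches(image, patch_size=4):
    """image: H x W x C nested lists. Returns flattened patches, row-major."""
    H, W = len(image), len(image[0])
    patches = []
    for top in range(0, H, patch_size):
        for left in range(0, W, patch_size):
            # Flatten one patch_size x patch_size x C patch into a single vector.
            patch = [value
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)
                     for value in image[r][c]]
            patches.append(patch)
    return patches


if __name__ == "__main__":
    H = W = 8
    image = [[[r, c, 0] for c in range(W)] for r in range(H)]
    patches = extract_patches(image)
    print("tokens:", len(patches), "raw feature length:", len(patches[0]))
```

For a 224 x 224 RGB input this gives 56 * 56 = 3136 tokens of 48 raw values each, which the projection then maps to C = 96 dimensions in Swin-T.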

PatchMerging

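This module downsamples between stages (in the code it is the downsample step at the end of every BasicLayer except the last): it reduces the resolution and adjusts the number of channels, which produces the hierarchical design and also saves some computation.

[Figure: PatchMerging illustration]

↑ At this point the channel dimension is 4 times the original (because H and W are each halved); a fully connected layer then reduces it to twice the original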

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # take every second element along rows and columns
        # (the four positions of each 2x2 block)
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        # normalize, then fuse to 2C with a linear layer
        x = self.norm(x)
        x = self.reduction(x)
        return x
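The strided slicing in `forward` gathers the four tokens of every 2 x 2 block and concatenates their features. The sketch below repeats that step on a small grid with plain lists, leaving out the LayerNorm and the linear reduction; `merge_patches` is a hypothetical helper.

```python
def merge_patches(grid):
    """grid: H x W list of feature lists (H, W even).

    Returns an H/2 x W/2 grid whose features are 4 times as long.
    """
    H, W = len(grid), len(grid[0])
    assert H % 2 == 0 and W % 2 == 0, "H and W must be even"
    merged = []
    for r in range(0, H, 2):
        row = []
        for c in range(0, W, 2):
            # Same order as torch.cat([x0, x1, x2, x3], -1):
            # top-left, bottom-left, top-right, bottom-right.
            row.append(grid[r][c] + grid[r + 1][c]
                       + grid[r][c + 1] + grid[r + 1][c + 1])
        merged.append(row)
    return merged


if __name__ == "__main__":
    # 4 x 4 grid with one feature per token, numbered in row-major order.
    grid = [[[r * 4 + c] for c in range(4)] for r in range(4)]
    for row in merge_patches(grid):
        print(row)
```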

Swin-T network architecture

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=3, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.018)
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        input_resolution=(56, 56), dim=96
        (reduction): Linear(in_features=384, out_features=192, bias=False)
        (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): BasicLayer(
      dim=192, input_resolution=(28, 28), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=192, input_resolution=(28, 28), num_heads=6, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=192, window_size=(7, 7), num_heads=6
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.036)
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          dim=192, input_resolution=(28, 28), num_heads=6, window_size=7, shift_size=3, mlp_ratio=4.0
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=192, window_size=(7, 7), num_heads=6
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.055)
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        input_resolution=(28, 28), dim=192
        (reduction): Linear(in_features=768, out_features=384, bias=False)
        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (2): BasicLayer(
      dim=384, input_resolution=(14, 14), depth=6
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=384, input_resolution=(14, 14), num_heads=12, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=384, window_size=(7, 7), num_heads=12
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.073)
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          dim=384, input_resolution=(14, 14), num_heads=12, window_size=7, shift_size=3, mlp_ratio=4.0
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=384, window_size=(7, 7), num_heads=12
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.091)
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (2): SwinTransformerBlock(
          dim=384, input_resolution=(14, 14), num_heads=12, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=384, window_size=(7, 7), num_heads=12
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.109)
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (3): SwinTransformerBlock(
          dim=384, input_resolution=(14, 14), num_heads=12, window_size=7, shift_size=3, mlp_ratio=4.0
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=384, window_size=(7, 7), num_heads=12
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.127)
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (4): SwinTransformerBlock(
          dim=384, input_resolution=(14, 14), num_heads=12, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=384, window_size=(7, 7), num_heads=12
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.145)
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (5): SwinTransformerBlock(
          dim=384, input_resolution=(14, 14), num_heads=12, window_size=7, shift_size=3, mlp_ratio=4.0
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=384, window_size=(7, 7), num_heads=12
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.164)
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        input_resolution=(14, 14), dim=384
        (reduction): Linear(in_features=1536, out_features=768, bias=False)
        (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      )
    )
    (3): BasicLayer(
      dim=768, input_resolution=(7, 7), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=768, input_resolution=(7, 7), num_heads=24, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=768, window_size=(7, 7), num_heads=24
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.182)
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          dim=768, input_resolution=(7, 7), num_heads=24, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=768, window_size=(7, 7), num_heads=24
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath(drop_prob=0.200)
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (avgpool): AdaptiveAvgPool1d(output_size=1)
  (head): Linear(in_features=768, out_features=1000, bias=True)
)
