【代码解析】mmaction2: Video Swin Transformer

目录

  • 1 网络结构
    • 1.1 代码
    • 1.2 解析
  • 2 实验结果

论文:https://arxiv.org/abs/2106.13230
源码:https://github.com/SwinTransformer/Video-Swin-Transformer

【代码解析】mmaction2: Video Swin Transformer_第1张图片

1 网络结构

在DHW三维上构建window进行self-attention提取,所以同时提取了spatial和temporal两个维度的关联性

1.1 代码

Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py

1.2 解析

  • SwinTransformer3D

    • patch_embed: PatchEmbed3D
      将输入三维信号切分成多个3d-patch,patch_size默认(2,4,4),对每个patch使用3d-conv进行特征提取并降采样
      • padding:对无法被patch_size整除维度进行填零padding
      • self.proj = conv3d(3, 96, kernel_size = patch_size, stride=patch_size):对输入特征进行三维卷积,即对每个patch_size大小窗口的输入进行特征提取,每个patch_size输出一个96维特征
      • norm(optional): fllatten + transpose + layer_norm(对channel维度进行norm,即对每个patch_size的96维特征进行归一化)+transpose
  • pos_drop: nn.Drop

  • self.layers : depths [2, 2, 6, 2] 多个BasicLayer进行串联

    • BasicLayer 进一步对上层输出信号切分成多个3d-window,window_size默认(8,7,7),对patch和patch之间的特征关联进行信息提取
      • get_window_size((D,H,W), window_size=(8,7,7), shift_size=(4,3,3))
      • rearrange(x, 'b c d h w -> b d h w c')
      • self.attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) 根据输入尺度和window_size生成transformer中的mask,对非自身window的特征关联信息进行抑制
        【代码解析】mmaction2: Video Swin Transformer_第2张图片
    • nn.ModuleList(SwinTransformerBlock3D(for i in range(depth)])多个SwinTransformerBlock3D进行串联 (B,D,H,W,C)
      在这里插入图片描述
      • nn.LayerNorm
      • F.pad
      • torch.roll(optional)
      • x_windows = window_partition: shape (B*nW, Wd*Wh*Ww, C) window切分
      • attn_windows = self.attn(x_windows, mask=attn_mask): WindowAttention3D 对window内部进行self-attention特征提取, shape (B*nW, Wd*Wh*Ww, C)
        • nn.Linear(dim, dim * 3, bias=qkv_bias) 将输入升维三倍
        • q, k, v = qkv[0], qkv[1], qkv[2] 提取K,Q,V特征
        1. q * self.scale = head_dim ** -0.5根据head_num进行缩放,防止multi-head大小对信号量影响过大
        2. attn = q @ k.transpose(-2, -1) 内积
        • attn + relative_position_bias: relative_position_bias_table 加入位置编码(防止特征顺序对transformer模块失效,不参与学习)
        • attn.view(B_ // nW, nW, self.num_heads, N, N) + mask 加入关联特征激活/抑制mask,这里mask就是之前提取的self.attn_mask
        • self.softmax(attn) + self.attn_drop(attn) Transformer标准模块
        • x = (attn @ v) Transformer标准模块
        • self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) Transformer标准模块
        • x = shortcut + self.drop_path(x) FFN模块
  • downsample: PatchMerging 对输出特征进行重排,H和W变为1/2(不对D进行降采样),channel会变成4倍【代码解析】mmaction2: Video Swin Transformer_第3张图片

    • 对H和W进行间隔采样
    • norm: nn.LayerNorm
    • nn.Linear(4 * dim, 2 * dim) channel降维
  • rearrange(x, 'b d h w c -> b c d h w')

  • rearrange + norm + rearrange

Swin-trans参数膨胀
inflate_weights

  • patch_embed 中的conv3d选择直接膨胀初始化conv2d
  • relative_position_bias_table 两种:膨胀初始化、中心初始化

2 实验结果

【代码解析】mmaction2: Video Swin Transformer_第4张图片

你可能感兴趣的:(深度学习,深度学习)