MViT-code

MViT模型

MViT-code_第1张图片

1.多头池化注意力(MHPA)

Multi Head Pooling Attention是本文的核心,它使得多尺度变换器已逐渐变化的时空分辨率进行操作。与原始的多头注意力(MHA)不同,在原始的多头注意力中,通道维度和时空分辨率保持不变,MHPA将潜在张量序列合并,以减少参与输入的序列长度(分辨率)。如下图所示,
MViT-code_第2张图片
Transformer只能处理1维数据,video通过 patch处理后形状改变为(L,D) L = T ∗ H ∗ W L=T*H*W L=THW即图中的THW,Self-attention计算公式主要是 Q K T ∗ V QK^T*V QKTV,假设 S h a p e Q = ( L Q , D ) , S h a p e K = ( L K , D ) S h a p e V = ( L V , D ) Shape_Q=(L_Q,D),Shape_K=(L_K,D)Shape_V=(L_V,D) ShapeQ=(LQ,D),ShapeK=(LK,D)ShapeV=(LV,D),则
S h a p e ( Q K T ) = ( L Q , D ) ( D , L K ) = ( L Q , D ) Shape(QK^T)=(L_Q,D)(D,L_K)=(L_Q,D) Shape(QKT)=(LQ,D)(D,LK)=(LQ,D)
S h a p e ( Q K T ∗ V ) = ( L Q , L K ) ∗ ( L V , D ) = ( L Q , D ) Shape(QK^T*V)=(L_Q,L_K)*(L_V,D)=(L_Q,D) Shape(QKTV)=(LQ,LK)(LV,D)=(LQ,D)

为了使公式成立,必须保证 L K = L V L_K=L_V LK=LV,即图中THW,所以为了降低空间分辨率,只需要改变Q向量的序列长度,所以对Q向量进行pooling操作即可,同时实验证明K,V向量pooling会提高指标,所以对K,V向量也进行了pooling操作,但是不会影响空间分辨率的大小,为了保证res connection成立,需要对输入X同样进行和Q向量一样的pooling操作
pooling操作又分为max/ average/ conv等,论文实验部分对不同的pooling操作进行了消融实验,最终确定为333 核的conv pooling操作。
如何提高通道数?
提高通道数就是通过简单的全连接层对向量维度D进行映射即可
代码

# pool通常是MaxPool3d或AvgPool3d
def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
    if pool is None:
        return tensor, thw_shape
    tensor_dim = tensor.ndim
    if tensor_dim == 4:
        pass
    elif tensor_dim == 3:
        tensor = tensor.unsqueeze(1)
    else:
        raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
 
    if has_cls_embed:
        cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
 
    B, N, L, C = tensor.shape
    T, H, W = thw_shape
    tensor = (tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous())
    # 执行pooling操作
    tensor = pool(tensor)
 
    thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
    L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
    tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
    if has_cls_embed:
        tensor = torch.cat((cls_tok, tensor), dim=2)
    if norm is not None:
        tensor = norm(tensor)
    # Assert tensor_dim in [3, 4]
    if tensor_dim == 4:
        pass
    else:  #  tensor_dim == 3:
        tensor = tensor.squeeze(1)
    return tensor, thw_shape

2.多尺度变换器网络(Multiscale Transformer Networks)

基于多头集中注意力(MHPA),本文创造了专门使用MHPA和MLP层进行视觉表征学习的多尺度变换器模型。在此之前,了解一下ViT模型。

2.1ViT

这里需要注意的是模型基于纯Transformer架构的,所以采用了Patch操作,详情参考ViT,所以图中的1,2,3,4是patch的大小,随着模型深入,patch是变大的,但是空间分辨(Patch分辨率)是降低的。

MViT

逐步增加信道维度,同时降低整个网络的时空分辨率(即序列长度)。MViT在早期层中具有精细的时空分辨率和低信道维度,而在后期层中,变为粗略的时空分辨率和高信道维度。MViT如表2所示,
MViT-code_第3张图片
需要注意之前提到了需要对数据进行patch操作,通过卷积实现(cube1),但是视频信号还有一个维度T,如图3所示,参数 s T sT sT代表cube1中计算卷积时对T维度的步长,是一个超参数,后续实验中出现的例如MViT-B 16*4 指的是输入16帧视频帧, s T sT sT取值4。
N ∗ N_* N代表的是每个stage使用的Transformer个数,MHPA(D)代表的是其处理向量的维度为D,MLP(4D)表示Transformer block中全连接层隐藏单元数为4D,即输入维度的四倍。
Scale stages
尺度阶段定义为一组N个变换器块,在相同的尺度上跨信道和时空维度以相同的分辨率运行。在阶段转换时,信道维度上采样,而序列的长度下采样。
每个stage都应用了若干个Transformer blocks,图2所示的是每个Transformer block都采用了pooling的操作,所以为了保证每个stage中只对空间分辨率进行一次下采样,只在每个stage的第一个Transformer block对向量Q进行 P o o l Q = ( 1 , 2 , 2 ) Pool_Q=(1,2,2) PoolQ=(1,2,2)的操作,通stagetage其余Transformer block的向量Q进行 P o o l Q = ( 1 , 1 , 1 ) Pool_Q=(1,1,1) PoolQ=(1,1,1)的操作。对K,V的pooling操作不影响空间分辨率,所以论文中在同一个stage的所有Transformer block的K,V都进行了同样的pooling操作,即 P o o l K = P o o l V = ( 1 , 8 , 8 ) Pool_K=Pool_V=(1,8,8) PoolK=PoolV=(1,8,8),
随着stage变深衰减, s c a l e 3 P o o l K = P o o l V = ( 1 , 8 , 8 ) , s c a l e 4 P o o l K = P o o l V = ( 1 , 4 , 4 ) scale3 Pool_K=Pool_V=(1,8,8),scale4 Pool_K=Pool_V=(1,4,4) scale3PoolK=PoolV=(1,8,8),scale4PoolK=PoolV=(1,4,4)
之前说过通过全连接层进行通道数增加的操作,图3中并未显示的展示,其实在两个stages之间存在一个过渡操作,即每个stage的output sizes需要通过一个全连接层将维度D进行映射,只进行通道数增加的操作,然后送入下一个stage进行计算。
MViT更改Transformer结构之后的multi-head的个数如何确定?
论文中维度D对应一个head,即图中scale5使用的Transformer blocks的multi-head个数等于8,scale4中head个数等于4,以此类推

图4(a)是ViT的框架,MHA就是普通的multi-head-attention,可以发现,在ViT中是没有分层结构的,输出和输入形状是一样的,MViT采用MHPA引入了分层结构,提出了两个不同大小的模型MViT-B/ MViT-S,值得注意的是,两个模型的体量都比较小,不到7G的显存就可以运行MViT-B,对在校学生来说非常友好的,显存足够的情况下有很大的改进空间。

你可能感兴趣的:(读论文,深度学习,python,机器学习)