【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision

论文原文:https://arxiv.org/abs/2107.02192
论文笔记:百度网盘提取码:nzsi

1. Summary Contributions:

  • (1)提出了一种长短时Transformer模型:Long-Short Transformer (Transformer-LS):
    • Short:利用滑动窗口获取短序列(局部)attention
    • Long:基于动态投影获取长序列(全局)attention
  • (2)在Long-Transformer中提出动态投影方法
  • (3)提出 DualLN(双向归一化)来解决不同term下维度不一致的问题
  • (4)可以很好加入视觉Transformer模型中取得好效果
    【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision_第1张图片

2. 基于滑动窗口计算Short-term Attention(双向模型)

  • Step1:图中每一行表示一个序列,长度n为8,将序列复制8次得到矩阵;(1维)窗口宽度为2,节点特征维度为d=3。

  • Step2:序列左右各有一个大小为w/2的padding,用长度为w的窗口将序列分为不相交的多等分,计算attention
    由于每个窗口包含w个token,但关注了2*w个token,因此:

    • input:n × d
    • output_attention:2w × d
      【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision_第2张图片
  • 这里每一个token的output_attention相当于其关于自身以及上下文信息的聚合,所以是short-term(局部)

3. 基于动态投影计算Long-term Attention(自回归模型)

  • 直观理解:利用低秩矩阵将节点数量维度 N 投影至更低维 r(r根据输入序列长度决定)
  • Step1:根据K构造投影矩阵Pi,从而Pi维度变为 n × r,其中r< 在这里插入图片描述
  • Step2:通过以下两公式,将K和V的维度从n×d变为 r×d,
    在这里插入图片描述
    扩展到多头即为如下公式,其中Q维度还是n×d,所以最终输出序列个数还是与原始相同
    【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision_第3张图片
  • Step3:将Long-term Attention用于自回归模型中
    【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision_第4张图片
    • (1):将Input划分为多个等长的L(相当于w)
    • (2):对于每个内部的patch,关注前面一个L和当前窗口内左边的token
    • (3)将该Attention计算方式放入自回归模型中:
      【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision_第5张图片
      • (4)并行计算每个L的attention,后将结果拼接起来:
        在这里插入图片描述
        最终输出的output_shape 为 r × dk

4. DualLN:双向归一化

【论文笔记2】Long-Short Transformer: Efficient Transformers for Language and Vision_第6张图片

  • 问题:如果直接拼接后进行LayerNorm:标准化后均值为0,而由于使用多头注意力机制进行加权平均的时候,会减小均值为0的向量的方差,从而减小该向量的范数。而对于Long和Short而言,其节点维度不同(Long>>Short),故Long这边的范数相对Short而言会较小,导致梯度更小,阻碍模型训练
  • 解决:分别进行LN后再拼接,公式如下
  • 在这里插入图片描述

5. 小结

  • (1)提出了short term的双向模型和Long-term的自回归模型,对比如右图:其在长序列数据处理中表现较好,主要是因为经过动态投影降低了attention的计算复杂度,所以序列越长相较而言越占优势。
  • (2)通过动态投影降低了attention的计算复杂度。
    传统的attention计算复杂度为O(n2),改进后由于将输入节点维度投影至r,而r是一个超参数,将使得计算复杂度降为了O(r*n)
  • (3)针对序列化数据,LSTM将上下文信息通过从网络浅层传至网络深层体现序列化,而LS-Transformer把序列化信息体现在attention的计算过程上。(有待考虑)

你可能感兴趣的:(深度学习,python)