按照定义一个a×b的矩阵乘以一个b×c的矩阵要做abc次乘法,所以abc就是两个矩阵相乘的复杂度了,这是我们估算Transformer复杂度的依据
设n为序列长度,d为head_size(base版是64),h为head的数目(base版是12),那么hd就是我们通常说的“hidden_size”(base版是768)。
对于SA来说:
Q,K,V的投影变换,即n×hd的矩阵乘以hd×hd的矩阵做3次,因此计算量是3n(hd)2;
h个Attention头的运算,每个头先是n×d的Q与d×n的KT相乘得到n×n的Attention矩阵(softmax和归一化的计算量暂且忽略),然后n×n的矩阵与n×d的V相乘得到n×d的矩阵,这两步的计算量都是n2d,所以总计算量是h(n2d+n2d);
输出投影变换,也是n×hd的矩阵乘以hd×hd的矩阵,计算量是n(hd)2
所以,SA的总计算量是
3n(hd)2+h(n2d+n2d)+n(hd)2=4nh2d2+2n2hd
FFN就是两个全连接层,也就是两个矩阵变换(激活函数的计算量也忽略不计),一般的参数设置是:第一层是n×hd的矩阵乘以hd×4hd的矩阵,第二层就是n×4hd的矩阵乘以4hd×hd的矩阵。所以总计算量是
n×hd×4hd+n×4hd×hd=8nh2d2
4nh2d2+2n2hd > 8nh2d2 ==> n>2hd
对于base版来说,这意味着n>1536!也就是说,只有当序列长度超过1536时,SA的计算量才大于FFN,在这之前,都是线性复杂度的FFN占主导
4nh2d2+2n2hd + 8nh2d2 = 12nh2d2+2n2hd
它是关于n的一次项和二次项的求和,当n足够大时,复杂度自然是(n2),然而二次项占主导的条件是
2n2hd>12nh2d2==> n>6hd
对于base版来说,这意味着n>4608!也就是说,当序列长度接近5000时,Transformer的复杂度才真正体现出二次性!
对于base版来说,当序列长度不超过1536时,Transformer的复杂度都是近乎线性的;
当序列长度超过1536时,Transformer的计算量逐渐以Attention为主,复杂度慢慢趋于二次方,直到长度超过4608,才真正以二次项为主
这些改进工作所关心的序列长度主要都是以千为单位的,有明显计算效率提升的序列长度基本上都要好几千;当然,我们前面的讨论主要针对的还是时间复杂度,对于空间复杂度,也就是显存占用量,降低的幅度一般要比时间复杂度提升的幅度的要大,但总体而言都是长序列才有价值。