transformer系列4---transformer结构计算量统计

transformer计算量

  • 1 术语解释
  • 2 矩阵相乘FLOPs
  • 3 Transformer的FLOPs估计
    • 3.1 MultiHeadAttention
      • 3.1.1 Q,K,V计算
      • 3.1.2 attention计算
      • 3.1.3 MultiHeadAttention输出线性映射
      • 3.1.4 MultiHeadAttention总计算量
    • 3.2 MLP
    • 3.3 projection输出
    • 3.3 计算量累计

1 术语解释

  1. FLOPs(Floating Point Operations):浮点运算次数,用来衡量模型计算复杂度,常用来做神经网络模型速度的间接衡量标准。(但该指标与实际模型速度并不一定正相关)

  2. MACs(Multiply–Accumulate Operations):乘加累积操作数,1个MACs包含一个乘法操作与一个加法操作,大约包含2FLOPs。通常MACs与FLOPs存在一个2倍的关系。

2 矩阵相乘FLOPs

  1. 对于 A 1 × n , B n × 1 A^{1×n},B^{n×1} A1×n,Bn×1两个矩阵相乘,计算AB需要进行n次乘法和n次加法,共计2n次浮点数运算 ,即2n的FLOPs,
  2. 对于 A m × n , B n × p A^{m×n},B^{n×p} Am×n,Bn×p两个矩阵相乘,计算AB需要进行2nmp次浮点数运算。

3 Transformer的FLOPs估计

假设Transformer的输入每个词向量维度d_model(d) ,词表大小为vocab_size(v),输入句子最大长度为src_max_len(s),batchsize为 batch(b),head头数为head(h)。
对于输入部分,将输入句子分词并且词嵌入步骤没有计算量,位置编码也没有计算量,因此,计算量主要集中在MultiHeadAttention、MLP、以及最后的投影计算。

3.1 MultiHeadAttention

3.1.1 Q,K,V计算

  1. 1个矩阵计算量:对于输入I,首先计算 Q = I ∗ W q Q =I * W^{q} Q=IWq K = I ∗ W k K = I * W^{k} K=IWk V = I ∗ W v V = I * W^{v} V=IWv,假设输入I的形状为 [b, s, d],1个矩阵乘法的输入和输出形状为[b, s, d] × [d, d] = [b, s, d],计算量为 2 b s d 2 2bsd^{2} 2bsd2
  2. 3个矩阵计算量: 6 b s d 2 6bsd^{2} 6bsd2

3.1.2 attention计算

  1. Q K T QK^{T} QKT

  矩阵乘法的输入形状[b, h, s, d] × [b, h, s, d],输出形状为 [b, h, s, s],h维度是concat,没有计算量,因此该步骤的计算量为 2 b s 2 d 2bs^{2}d 2bs2d

  1. score*V加权
    输入形状为[b, h, s, s] × [b, h, s, d],输出形状为[b, h, s, d], h维度是concat,没有计算量,因此该步骤的计算量为 2 b s 2 d 2bs^{2}d 2bs2d

3.1.3 MultiHeadAttention输出线性映射

所有head都concat,输入形状为[b, s, d] × [d, d] ( W O ) (W^{O}) (WO),输出形状为[b, s, d],计算量 2 b s d 2 2bsd^{2} 2bsd2

3.1.4 MultiHeadAttention总计算量

MultiHeadAttention总计算量为上面三部分之和, 2 b s 2 d 2bs^{2}d 2bs2d+ 2 b s 2 d 2bs^{2}d 2bs2d+ 2 b s d 2 2bsd^{2} 2bsd2= 4 b s 2 d 4bs^{2}d 4bs2d+ 2 b s d 2 2bsd^{2} 2bsd2

3.2 MLP

MLP内包含2个线性层:

  1. 第一个线性层,矩阵乘法输入形状为[b, s, d] × [d, 4d],输出形状为[b, s, 4d],计算量 8 b s d 2 8bsd^{2} 8bsd2
  2. 第二个线性层,矩阵乘法输入形状为[b, s, 4d] × [4d, d],输出形状为[b, s, d],计算量 8 b s d 2 8bsd^{2} 8bsd2

MLP总计算量为 8 b s d 2 8bsd^{2} 8bsd2+ 8 b s d 2 8bsd^{2} 8bsd2= 16 b s d 2 16bsd^{2} 16bsd2

3.3 projection输出

logits的计算,将隐藏向量映射为词表大小。矩阵乘法输入形状为[b, s, d] × [d, v],输出形状为[b, s, v],计算量 2 b s d v 2bsdv 2bsdv

3.3 计算量累计

  1. Transformer的encoder,包含1个MultiHeadAttention,1个MLP
  2. Transformer的decoder,包含2个MultiHeadAttention,1个MLP
  3. Transformer的输出为1个projection
    将上面3部分累加,计算量为 4 b s 2 d 4bs^{2}d 4bs2d+ 2 b s d 2 2bsd^{2} 2bsd2+ 16 b s d 2 16bsd^{2} 16bsd2+2*( 4 b s 2 d 4bs^{2}d 4bs2d+ 2 b s d 2 2bsd^{2} 2bsd2)+ 16 b s d 2 16bsd^{2} 16bsd2+ 2 b s d v 2bsdv 2bsdv= 12 b s 2 d 12bs^{2}d 12bs2d+ 36 b s d 2 36bsd^{2} 36bsd2+ 2 b s d v 2bsdv 2bsdv

你可能感兴趣的:(transformer,transformer,深度学习,人工智能)