分析transformer模型的参数量、计算量、中间激活、KV cache

难得一遇的好文,转载自https://zhuanlan.zhihu.com/p/624740065,先做个简单总结:

训练时的参数量由以下模型参数前向的中间激活后向的梯度优化器参数组成:

模型参数

假设Transformer的hidden_size是h,那么总的一层transformer层的参数量为 12 h 2 + 13 h 12h^2+13h 12h2+13h ,l层的transformer计算量就是 l ∗ ( 12 h 2 + 13 h ) l*(12h^2+13h) l(12h2+13h):

  • 多头注意力四个参数矩阵Q、K、V、O和他们的bias对应 4 h 2 + 4 h 4h^2+4h 4h2+4h
  • position wise feedforward对应 4 h 2 + 4 h + 4 h 2 + h 4h^2+4h+4h^2+h 4h2+4h+4h2+h
  • 多头注意力和position wise feedforward各有一个LayerNorm对应2个可训练模型缩放参数beta和平移参数gamma参数是4h

前向计算过程中产生的中间激活

前向计算过程中产生的中间激活,中间激活值与输入数据的大小(批次大小b和序列长度 l)是成正相关的,随着批次大小b和序列长度l的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存

后向传递计算得到的梯度、优化器状态

在一次训练迭代中,每个可训练模型参数都会对应1个梯度,并对应2个优化器状态(Adam优化器梯度的一阶动量和二阶动量)。设模型参数量为 Φ \Phi Φ,那么梯度的元素数量为 Φ \Phi Φ,AdamW优化器的元素数量为 2 Φ 2\Phi 。float16数据类型的元素占2个bytes,float32数据类型的元素占4个bytes。在混合精度训练中,会使用float16的模型参数进行前向传递和后向传递,计算得到float16的梯度;在优化器更新模型参数时,会使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。因此,对于每个可训练模型参数,占用了 ( 2 + 4 ) + ( 2 + 4 ) + ( 4 + 4 ) = 20 b y t e s (2+4)+(2+4)+(4+4)=20bytes (2+4)+(2+4)+(4+4)=20bytes。使用AdamW优化器和混合精度训练来训练参数量为 Φ \Phi Φ的大模型,模型参数、梯度和优化器状态占用的显存大小为 20 Φ 20\Phi 20Φ

推理时的参数量少了梯度、优化器状态、中间激活,但多了kvcache

假设输入序列的长度为s,输出序列的长度为n,以float16来保存KV cache,那么KV cache的峰值显存占用大小为 b ∗ ( s + n ) h ∗ l ∗ 2 ∗ 2 = 4 l h ( s + n ) b*(s+n)h*l*2*2 =4lh(s +n) b(s+n)hl22=4lh(s+n)。这里第一个2表示K/V cache,第个2表示float16占2个bytes。
以GPT3为例,对比KV cache与模型参数占用显存的大小。GPT3模型占用显存大小为350GB。假设批次大小b=64,输入席列长度 =512,输出序列长度n =32,则KV cache占用显存大约为 46 l h ( s + n ) 46lh(s + n) 46lh(s+n)= 164,282,499,072bytes约等于164GB,大约是模型参数显存的0.5倍

训练时的计算量

FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小

对于 A ∈ R m × n A \in R^{m \times n} ARm×n, B ∈ R n × p B \in R^{n \times p} BRn×p,矩阵A✖️B浮点运算需要mnp次乘法和mnp次加法,因此FLOPs为2mnp。这个刚好是计算复杂度*2的关系,因此我们下面只估计乘法的计算次数,最后算FlOPs乘2就好了

模型计算复杂度

假设Transformer的hidden_size是h,序列长度是s,batch大小是b,那么总的一层transformer层的参数量为 12 b s h 2 + 2 b s 2 h 12bsh^2+2bs^2h 12bsh2+2bs2h ,l层的transformer计算量就是 l ∗ ( 12 b s h 2 + 2 b s 2 h ) l*(12bsh^2+2bs^2h) l(12bsh2+2bs2h):

  • 多头注意力四个参数矩阵Q、K、V、O对应 4 b s h 2 4bsh^2 4bsh2(忽略他们的bias)
  • QK、KV的复杂度是 2 b s 2 h 2bs^2h 2bs2h [ b , s , h ] ∗ [ b , h , s ] = [ b , s , s ] [b,s,h]*[b,h,s]=[b,s,s] [b,s,h][b,h,s]=[b,s,s] [ b , s , s ] ∗ [ b , s , h ] = [ b , s , h ] [b,s,s]*[b,s,h]=[b,s,h] [b,s,s][b,s,h]=[b,s,h],这俩都是 b s 2 h bs^2h bs2h
  • position wise feedforward对应 4 b s h 2 + 4 b s h 2 4bsh^2+4bsh^2 4bsh2+4bsh2

计算量与参数量的关联

当隐藏维度h比较大,且远大于序列长度s时,我们可以忽略一次项,计算量可以近似为 ( 12 b s h 2 + 2 b s 2 h ) ∗ 2 12 h 2 + 13 h ≈ 2 \frac{(12bsh^2+2bs^2h)*2}{12h^2+13h}\approx2 12h2+13h(12bsh2+2bs2h)22 。我们可以近似认为:在一次前向传递中,对于每个token,每个模型参数,需要进行2次浮点数运算,即一次乘法法运算和一次加法运算。

一次训练迭代包含了前向传递和后向传递,后向传递的计算量是前向传递的2倍(loss.backward()和optimizer.step()两步)。因此,前向传递 + 后向传递的系数 = 1+2 =3 。一次训练迭代中,对于每个token,每个模型参数,需要进行 2*3=6次浮点数运算。

训练时间估计

模型参数量和训练总tokens数决定了训练transformer模型需要的计算量。给定硬件GPU类型的情况下,可以估计所需要的训练时间。给定计算量,训练时间(也就是GPU算完这么多flops的计算时间)不仅跟GPU类型有关,还与GPU利用率有关。计算端到端训练的GPU利用率时,不仅要考虑前向传递和后向传递的计算时间,还要考虑CPU加载数据、优化器更新、多卡通信和记录日志的时间。一般来讲,GPU利用率一般在0.3~0.55之间。

上文讲到一次前向传递中,对于每个token,每个模型参数,进行2次浮点数计算。使用激活重计算技术来减少中间激活显存(下文会详细介绍)需要进行一次额外的前向传递,因此前向传递 + 后向传递 + 激活重计算的系数=1+2+1=4。使用激活重计算的一次训练迭代中,对于每个token,每个模型参数,需要进行2*4=8次浮点数运算。在给定训练tokens数、硬件环境配置的情况下,训练transformer模型的计算时间为:
分析transformer模型的参数量、计算量、中间激活、KV cache_第1张图片


以下是转载截屏:
分析transformer模型的参数量、计算量、中间激活、KV cache_第2张图片


你可能感兴趣的:(transformer,深度学习,人工智能)