transformer系列之空间复杂度

一、加载阶段

该阶段是指将模型加载进GPU的状态; 该阶段仅仅需要计算模型的参数量足以;
transformer模型由 l l l个相同的层组成,每个层分为两部分:self-attention块和MLP块,如图所示:
mistral-7b
transformer系列之空间复杂度_第1张图片
llama-7b
transformer系列之空间复杂度_第2张图片

self-attention块的模型参数有 Q K V QKV QKV的权重矩阵 W Q , W K , W V W_Q,W_K,W_V WQ,WK,WV和偏置,输出权重矩阵 W O W_O WO和偏置,
4个权重矩阵的形状为 [ h , h ] [h,h] [h,h],4个偏置的形状为 [ 1 , h ] [1,h] [1,h],self- attention块的参数量为 4 h 2 + 4 h 4h^2+4h 4h2+4h
上述结构是学术上经常提到的结构,但各个开源大模型的结构都会有些变化,且没有偏置项。

MLP块由2个线性层组成,第一个线性层 W u p 与 W b ,维度分别为 [ h , 4 h ] , [ 1 , 4 h ] W_{up}与W_b,维度分别为[h,4h],[1,4h] WupWb,维度分别为[h,4h][1,4h]; 第二个线性层 W d o w n W_{down} Wdown W b W_b Wb维度分别为 [ 4 h , h ] , [ 1 , h ] [4h,h],[1,h] [4h,h],[1,h];
参数量为 8 ∗ h 2 + 5 ∗ h 8*h^2 + 5*h 8h2+5h

self-attention块和MLP块各有一个layer normalization,包含了2个可训练模型参数:缩放参数 γ \gamma γ和平移参数 β \beta β,形状都是 [ h ] [h] [h],2个layer normalization的参数量为 4 h 4h 4h

总的,每个transformer层的参数量为 12 h 2 + 13 h 12h^2+13h 12h2+13h;
除此之外,词嵌入矩阵的参数量也较多,词向量维度通常等于隐藏层维度 h h h, 词嵌入矩阵的参数量为 V h Vh Vh 最后的输出层的权重矩阵通常与词嵌入矩阵是参数共享的;

关于位置编码,如果采用可训练式的位置编码,会有一些可训练模型参数,数量比较少。如果采用相对位置编码,例如RoPE和ALiBi,则不包含可训练的模型参数。我们忽略这部分参数。

综上, l l l层transformer模型的可训练模型参数量为 l ∗ ( 12 h 2 + 13 h ) + V h l*(12h^2+13h)+Vh l(12h2+13h)+Vh; 当隐藏维度 h h h较大时,可以忽略一次项,而且当前开源大模型的偏转项都是为False,所以模型参数量可以近似为 12 l h 2 12lh^2 12lh2
llama系列模型:
transformer系列之空间复杂度_第3张图片
来源:https://zhuanlan.zhihu.com/p/624740065
但有个疑问,32的头的数量为什么没有计算进去,假设是4096维度分成32个头,每个头的维度为128呢?
transformer系列之空间复杂度_第4张图片

二、推理阶段

在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。少了梯度、优化器状态、中间激活,模型推理阶段占用的显存要远小于训练阶段。模型推理阶段,占用显存的大头主要是模型参数,如果使用float16来进行推理,推理阶段模型参数占用的显存大概是 2 Φ 2\Phi bytes。如果使用KV cache来加速推理过程,KV cache也需要占用显存。
transformer系列之空间复杂度_第5张图片
如果输入序列长度+输出序列长度扩大100倍,那么显存占用量会是模型参数的50倍,就会变得非常恐怖。

三、训练阶段

由于当前模型训练大都采用混合精度训练,所以也以混合精度训练为例;
训练所需显存主要分成四个部分,假设模型参数量为 Φ \Phi Φ

  1. 模型参数
    用fp16保存,故占用字节数为2 Φ \Phi Φ
  2. 后向传递的梯度与优化器状态
    transformer系列之空间复杂度_第6张图片
    图片来自于:https://zhuanlan.zhihu.com/p/608634079
    该过程占用18 Φ \Phi Φ,包括计算得到的梯度2 Φ \Phi Φ,优化器一阶与二阶为8 Φ \Phi Φ,模型权重副本及梯度副本为8 Φ \Phi Φ
  3. 前向计算过程中产生的中间激活
    中间激活的概念:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。这里的激活不包含模型参数和优化器状态,但包含了dropout操作需要用到的mask矩阵。
    为什么需要计算中间激活值:因为需要保存中间激活以便在后向传递计算梯度时使用
    为什么中间激活是显存消耗大户呢:因为在一次训练迭代中,模型参数(或梯度)占用的显存大小只与模型参数量和参数数据类型有关,与输入数据的大小是没有关系的。优化器状态占用的显存大小也是一样,与优化器类型有关,与模型参数量有关,但与输入数据的大小无关。而中间激活值与输入数据的大小(批次大小 和序列长度 )是成正相关的,随着批次大小和序列长度 的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。
    transformer系列之空间复杂度_第7张图片
    https://zhuanlan.zhihu.com/p/624740065
    中间激活使用的显存量

你可能感兴趣的:(人工智能)