transformer系列5---transformer显存占用分析

Transformer显存占用分析

  • 1 影响因素概述
  • 2 前向计算临时Tensor显存占用
    • 2.1 self-attention显存占用
    • 2.2 MLP显存占用
  • 3 梯度和优化器显存占用
    • 3.1 模型训练过程两者显存占用
    • 3.2 模型推理过程两者显存占用

1 影响因素概述

  1. 模型训练框架:例如pytorch框架的cuda context会占用大约几百MB显存,与版本有关;
  2. 模型参数大小,比如7B的模型以FP16格式要占用14GB显存;
  3. 前向计算过程中产生的临时Tensor:这部分Tensor需要被临时保存,以便在反向传播计算梯度时使用
  4. 反向传播计算得到的梯度:
  5. 优化器状态:全量微调的情况下,梯度与参数一样大,普通SGD没有动量,一阶动量优化器的自身参数大小与模型大小一样,比如momentum-SGD,二阶动量优化器一般为模型大小的两倍,比如Adam, transformer系列的大模型最常用的是Adam优化器

2 前向计算临时Tensor显存占用

2.1 self-attention显存占用

这部分Tensor的大小和模型的每一层结构形状有关(必须根据具体模型的每层形状来计算)也和具体的batch_size大小以及输入数据input_data的大小有关。

  1. 输入矩阵I:首先计算 Q = I ∗ W q Q =I * W^{q} Q=IWq K = I ∗ W k K = I * W^{k} K=IWk V = I ∗ W v V = I * W^{v} V=IWv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
  2. Q K T QK^{T} QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
  3. softmax: A = Q K T A=QK^{T} A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes* b h s 2 bhs^{2} bhs2= 2 b h s 2 2bhs^{2} 2bhs2 bytes。
  4. dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。
  5. score* V加权:score矩阵的形状与A相同,占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd bytes。
  6. W O W^{O} WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
  7. dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
    综上步骤,self-attention块的占用显存大小为2bsd+4bsd+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd+2bsd+2bsd=11bsd+ 5 b h s 2 5bhs^{2} 5bhs2

2.2 MLP显存占用

  1. 第一个线性层需要保存其输入,输入形状为[b, s, d],占用显存大小为 2bytes*bsd=2bsd bytes。
  2. 激活函数需要保存其输入,为第一步的输出形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  3. 第二个线性层需要保存其输入,输入形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  4. 最后有一个dropout操作,需要保存mask矩阵,形状是上一步的输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。

综上步骤,MLP的占用显存大小为2bsd+8bsd+8bsd+bsd=19bsd.

3 梯度和优化器显存占用

3.1 模型训练过程两者显存占用

参数占用显存 = 参数数目 × n
n = 2 : float16
n = 4 : float32
n = 8 : double64
其中,float32是最常用的类型,n是数据类型占用的bytes。
训练过程通常为模型参数前向传播,反向传播计算梯度,优化器更新,以Adam优化器为例分析,假如模型参数量为P:

  1. 混合精度训练:
    1)使用float16的模型参数进行前向传递和反向传播,计算得到float16的梯度;
    2)在优化器更新模型参数时,使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
    3)对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是2bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是2bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(2+4)+(2+4)+8 = 20bytes。模型参数量M时总计20P bytes。

  1. 普通训练:
    上述步骤1)2)均使用float32类型。对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是4bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是4bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(4+4)+(4+4)+8 = 24bytes,模型参数量M时总计24P bytes。

3.2 模型推理过程两者显存占用

推理占用显存主要是模型参数,假如模型参数量为P,使用float16来进行推理,推理阶段模型参数占用的显存约2P bytes,使用float32来进行推理,推理阶段模型参数占用的显存约 4P bytes。

参考文章:https://zhuanlan.zhihu.com/p/624740065?utm_id=0

你可能感兴趣的:(transformer,transformer,深度学习,人工智能)