大模型训练显存优化推理加速方案

当前的深度学习框架大都采用的都是fp32来进行权重参数的存储,比如Python float的类型为双精度浮点数fp64,pytorch Tensor的默认类型为单精度浮点数fp32。随着模型越来越大,加速训练模型的需求就产生了。在深度学习模型中使用fp32主要存在几个问题,第一模型尺寸大,训练的时候对显卡的显存要求高;第二模型训练速度慢;第三模型推理速度慢。其解决方案就是使用低精度计算对模型进行优化。本文主要讲解几种优化显存存储的方法。

1. fp32、fp16、bf16混合精度训练

  • FP32 是单精度浮点数,1位符号位,8位指数,23位表示小数,总共32位
  • BF16 是对FP32单精度浮点数截断数据,即用8bit 表示指数,7bit 表示小数
  • FP16 半精度浮点数,用5bit 表示指数,10bit 表示小数;

    与32位相比,采用BF16/FP16吞吐量可以翻倍,内存需求可以减半。但是这两者精度上差异不一样,BF16 可表示的整数范围更广泛,但是尾数精度较小;FP16 表示整数范围较小,但是尾数精度较高。

1.1 混合精度训练

直接使用半精度进行计算会导致的两个问题的处理:舍入误差(Rounding Error)和溢出错误(Grad Overflow / Underflow)

  • 舍入误差
    float16 的最大舍入误差约为   2 − 10 ~2 ^{-10}  210,比 float32 的最大舍入误差   2 − 23 ~2 ^{-23}  223 要大不少。 对足够小的浮点数执行的任何操作都会将该值四舍五入到零。在反向传播中很多梯度更新值都非常小,但不为零,在反向传播中舍入误差累积可以把这些数字变成0或者nan, 这会导致不准确的梯度更新,影响网络的收敛

  • 溢出错误
    由于 float16 的有效的动态范围(正数部分,负数部分与正数对应)约为 5.96 × 1 0 − 8 ∼ 6.55 × 10 4 5.96\times10^{-8} \sim 6.55\times10{^4} 5.96×1086.55×104,比单精度的 float32 的动态范围 1.4 × 1 0 − 45 ∼ 1.7 × 1 0 38 1.4\times10^{-45} \sim 1.7 \times10^{38} 1.4×10451.7×1038要狭窄很多,精度下降会导致得到的值大于或者小于fp16的有效动态范围,也就是上溢出或者下溢出。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况

针对以上两种情况的解决方法是混合精度训练(Mixed Precision)和损失缩放(Loss Scaling)

  • 混合精度训练
    混合精度训练是一种通过在FP16上执行尽可能多的操作来大幅度减少神经网络训练时间的技术,在像线性层或是卷积操作上,FP16运算较快,但像Reduction运算又需要 FP32的动态范围。通过混合精度训练的方式,便可以在部分运算操作使用FP16,另一部分则使用 FP32,混合精度功能会尝试为每个运算使用相匹配的数据类型,在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。这样在权重更新的时候就不会出现舍入误差导致更新失败,混合精度训练的策略有效地缓解了舍入误差的问题

  • 损失缩放
    尽管使用了混合精度训练,还是会存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出。损失缩放是指在执行反向传播之前,将损失函数的输出乘以某个标量数(论文建议从8开始)。 乘性增加的损失值产生乘性增加的梯度更新值,提升许多梯度更新值到超过FP16的安全阈值2^-24。 只要确保在应用梯度更新之前撤消缩放,并且不要选择一个太大的缩放以至于产生inf权重更新(上溢出) ,从而导致网络向相反的方向发散

bf16/fp32 混合训练因为两种格式在范围上对齐了,并且 bf16 比 fp16 的范围更大,所以要比 fp16/fp32 混合训练稳定性更高

2. gradient checkpointing

gradient checkpointing(梯度检查点)的工作原理是在反向传播时重新计算深度神经网络的中间值(通常情况是在前向传播时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

3. Xformers

Xformers 应该是目前社区知名度最高的优化加速方案了,所谓 Xformers 指的是该库将各种transformer 架构的模型囊括其中。注意,该库仅适用于N卡,特点是加速图片生成并降低显存占用,代价是输出图像不稳定,有可能比不开Xformers略差。各种transformer变体可以参考 A Survey of Transformers.

参考

  • 彻底搞懂float16与float32的计算方式
  • pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等
  • facebookresearch/xformers

你可能感兴趣的:(自然语言处理,stable,diffusion,AIGC,人工智能)