作为一名算法工程师,我们经常会遇到训练网络的事情,当前训练网络的整个过程基本上都是在N卡上面执行的,当我们的数据集比较大时,训练网络会耗费大量的时间。由于我们需要使用反向传播来更新具有细微变化的权重,因而我们在训练网络的过程中通常会选用FP32类型的数据和权重。说了这么多,那么混合精度到底是什么呢,有什么用呢?
简而言之,所谓的混合精度训练,即当你使用N卡训练你的网络时,混合精度会在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。它的优势就是可以使你的训练时间减少一半左右。它的缺陷是只能在支持FP16操作的一些特定类型的显卡上面使用,而且会存在溢出误差和舍入误差。
区别:
优点1-FP16计算速度更快、更加节约内存
上图展示了FP16和FP32在内存消耗上面的不同之处。通过观察上图我们可以得出:
优点2-FP16可以使用上特定显卡中专门为加速所设计的Tensor Core
上图展示了执行卷积的过程(乘操作和加操作)。使用FP16执行成操作,然后使用FP16或者FP32执行乘操作。与使用FP32计算相比,在Volta V100(该架构中存在Tensor Core,支持FP16操作)上面可以获得8倍的性能提速,最终达到125TFlops的扇出。
缺点1-FP16会带来梯度溢出错误
Grad Overflow / Underflow,即梯度溢出。由于FP16的动态范围是 5.96 × 1 0 − 8 < x < 65504 5.96 \times 10^{-8}
缺点2-FP16会带来舍入误差
舍入误差,即当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败,具体的细节如下图所示,由于更新的梯度值超出了FP16能够表示的最小值的范围,因此该数值将会被舍弃,这个权重将不进行更新。
解决方案:
使用FP32训练代码如下所示:
# coding=utf-8
import torch
N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device=“cuda”)
y = torch.randn(N, D_out, device=“cuda”)
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for t in range(500):
y_pred = model(x)
loss = torch.nn.functional.mse_loss(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
使用FP16训练代码如下所示,仅仅需要在原始的Pytorch代码中增加3行代码,你就可以体验到极致的性能加速啦。
# coding=utf-8
import torch
N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device=“cuda”)
y = torch.randn(N, D_out, device=“cuda”)
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
for t in range(500):
y_pred = model(x)
loss = torch.nn.functional.mse_loss(y_pred, y)
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
1、model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
这行代码的主要作用是对模型和优化器执行初始化操作,方便后续的混合精度训练。其中opt_level表示优化的等级,当前支持4个等级的优化,具体的情况如下图所示:
注意事项:
1、cast_model_type表示的是模型的输入类型。当前支持的类型包括torch.float32和torch.float16;
2、patch _torch_function表示的是根据不同函数的输入数据要求获得一个最优的输入类型。GEMM和Convolution等运算可以使用FP16快速的获得最终的结果,由于softma x/exponentiation/pow等运算需要较高的精度,所以选择使用FP32来计算。当前支持的类型包括False和True。
3、keep_batchnorm _fp32表示的是是否需要对网络中BN层执行特殊处理。由于网络中的BN层会影响数据的分布情况,从而进一步影响网络的训练过程,因此需要认真的去处理这个类型的层。当该层使用FP32时,网络的训练过程会更加稳定。当前支持的类型包括False和True。
4、master_weight 表示的是网络在训练过程中部分参数使用FP32来表示,部分参数使用FP16来表示。上图中蓝色的框表示FP32类型,绿色的框表示FP16类型,FP32在转化为FP16的过程中会进行备份(master_0),optimizer都是使用FP32来表示,而model部分中部分参数是FP16类型,部分参数是FP32类型,梯度更新的过程通常是在master上面执行。
5、loss_scale表示是是否需要执行损失放大操作。1.0表示不需要执行损失放大操作,dynamic表示需要执行损失放大操作。
NVIDIA官方给出的使用规则如下所示:
2、 with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
这行代码的主要作用是在反向传播前进行梯度放大来进行更新,在反向传播后进行梯度缩放,返回原来的值,但是可以很好的解决由于梯度值太小模型无法更新的问题。具体的细节如下图所示:
上图展示了FP16在计算的过程中由于梯度值太小,超出了FP16能表示的下限值,因而无法进行权重更新,导致网络不收敛。
上图展示了使用损失方法(Scaled Loss)的方法来很好的解决这个问题,即在反向传播之前,给这些比较少的数值乘上 2 k 2^{k} 2k,即将其扩大 2 k 2^{k} 2k倍,将其调整到FP16能够支持的一个合理的范围内,那么FP16就可以对这个比较小的梯度增量值执行更新,这样就可以很好的解决这个问题。
上图展示了对反向传播之后的结果之后后处理的过程,由于我们为了解决方向传播之前梯度数值太小而将它扩大 2 k 2^{k} 2k倍,那么这样计算之后就相当于我们认为的将梯度值增加了 2 k 2^{k} 2k倍,为了获得准确的权重值,我们需要在反向传播之后除以 2 k 2^{k} 2k,整个过程在optimizer.step()执行之前。
上图展示了使用混合训练在多个经典模型上面的加速效果。在BERT模型上,使用混合精度训练可以获得4倍的提速。换句话说,我们原本需要4天才能训练好的模型,现在1天就可以训练出来,而且能够达到几乎相同的精度级别。这在很多情况下还是挺有用的,这个方法在减少模型训练时间的同时可以节省更多的电费,除此之外,可以节约算法工程师们的时间,从而提高他们的工作效率。
上图展示了使用混合精度训练的模型的精度。通过观察我们可以得出以下的结论:混合精度训练在提升训练速度的同时可以达到和FP32训练同样的精度。
通过仔细理解上面的内容,你应该会对混合精度训练有了一个全新的认识。所谓的混合精度训练,即当你使用N卡训练你的网络时,混合精度会在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。它的优势就是可以使你的训练时间减少一半左右。它的缺陷是只能在支持FP16操作的一些特定类型的显卡上面使用,而且会存在溢出误差和舍入误差。总而言之,混合精度训练可以在保证精度的同时极大的提升你的训练速度,如果你习惯使用pytorch来训练网络,那你就可以获得极致的训练速度啦。当前混合精度训练仍然存在着一些限制条件,首先,你的硬件设备需要支持FP16计算;其次,你的硬件设备需要具有Tensor_Core单元(这仅仅存在于一些新架构的N卡上面);接着,当前的仅有少量的深度学习框架支持混合精度训练(Pytorch);最后,混合精度不仅仅可以用在网络训练的过程中,同样也可以将它应用在网络推理过程中执行加速。随着越来越多的硬件设备支持FP16计算之后,混合精度训练和推理应该会成为一个首选,我相信越来越多的训练和推理框架都会在短期内逐渐支持混合精度训练。
[1] 参考博客
[2] NVIDIA参考资料
[3] GTC_2019
[4] apex
[1] 该博客是本人原创博客,如果您对该博客感兴趣,想要转载该博客,请与我联系(qq邮箱:[email protected]),我会在第一时间回复大家,谢谢大家的关注.
[2] 由于个人能力有限,该博客可能存在很多的问题,希望大家能够提出改进意见。
[3] 如果您在阅读本博客时遇到不理解的地方,希望您可以联系我,我会及时的回复您,和您交流想法和意见,谢谢。
[4] 本人业余时间承接各种本科毕设设计和各种小项目,包括图像处理(数据挖掘、机器学习、深度学习等)、matlab仿真、python算法及仿真等,有需要的请加QQ:1575262785详聊,备注“项目”!!!