PyTorch 半精度训练踩坑

背景

因为显卡显存不够,所以了解了一些PyTorch节省显存的方法:
拿什么拯救我的 4G 显卡 - OpenMMLab的文章 - 知乎
https://zhuanlan.zhihu.com/p/430123077

其中有一种方法叫做半精度训练:
PyTorch的自动混合精度(AMP) - Gemfield的文章 - 知乎
https://zhuanlan.zhihu.com/p/165152789

什么是半精度训练呢?
PyTorch中默认创建的tensor都是FloatTensor类型。而在PyTorch中,一共有10种类型的tensor:

torch.FloatTensor (32-bit floating point)
torch.DoubleTensor (64-bit floating point)
torch.HalfTensor (16-bit floating point 1)
torch.BFloat16Tensor (16-bit floating point 2)
torch.ByteTensor (8-bit integer (unsigned))
torch.CharTensor (8-bit integer (signed))
torch.ShortTensor (16-bit integer (signed))
torch.IntTensor (32-bit integer (signed))
torch.LongTensor (64-bit integer (signed))
torch.BoolTensor (Boolean)

所谓半精度训练,就是用torch.HalfTensor进行训练,以FP16的方式存储数据(本来是FP32),从而节省显存。

使用半精度训练的方式也很简单:
参考文章pytorch 使用amp.autocast半精度加速训练

即使用autocast + GradScaler

  1. autocast
    需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的
from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()
  1. GradScaler
    GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。
    因此PyTorch中经典的AMP使用方式如下:
from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

问题

训练时loss出现了nan。
一步步向上追溯,发现T5.encoder就已经输出了nan。
Google了一下,发现并不是我一个人遇到这个问题:
https://github.com/huggingface/transformers/issues/4287

可能的解决办法

1. 取消半精度训练
简单直接,我就用了这种方式
2. 其他思路

  • 计算loss 时,出现了除以0的情况
  • loss过大,被半精度判断为inf
  • 网络参数中有nan,那么运算结果也会输出nan

参考:解决pytorch半精度amp训练nan问题

你可能感兴趣的:(PyTorch 半精度训练踩坑)