pytorch 基于 apex.amp 的混合精度训练:原理介绍与实现

文章目录

    • 1. 混合精度训练介绍
      • 1.1 FP16 与 FP32
      • 1.2 为什么要使用混合精度训练?
      • 1.3 使用fp16带来的问题及解决方法
    • 2. apex 介绍与安装
    • 3. apex.amp 的使用
      • 3.1 三行代码实现 amp
      • 3.2 参数配置
      • 3.3 amp测试:MNIST 手写数字识别
    • 4. 参考资料推荐

1. 混合精度训练介绍

所谓天下武功,唯快不破。我们在训练模型时,往往受制于显存空间只能选取较小的 batch size,导致训练时间过长,使人逐渐烦躁。那么有没有可能在显存空间不变的情况下提高训练速度呢?混合精度训练(Mixed Precision)便油然而生。

常见的模型加速的方法有很多,混合精度是其中一种。

pytorch 基于 apex.amp 的混合精度训练:原理介绍与实现_第1张图片

1.1 FP16 与 FP32

  • fp16(float16):Half-precision floating-point format 半精度浮点数
  • fp32(float32):单精度浮点数
  • fp64(float64):双精度浮点数

FP16 与 FP32 的存储方式和精度参考博客:fp16与fp32简介与试验

混合精度训练的精髓在于在内存中用 fp16 做储存和乘法从而加速计算,用 fp32 做累加避免舍入误差。

1.2 为什么要使用混合精度训练?

神经网络框架的计算核心是Tensor,pytorch 中定义一个Tensor其默认类型是fp32。目前大多数的深度学习模型使用的是 fp32 进行训练,而混合精度训练的方法则通过 fp16 进行深度学习模型训练,从而减少了训练深度学习模型所需的内存,同时由于 fp16 的运算比 fp32 运算更快,从而也进一步提高了硬件效率。总之,混合16位和32位的计算可以节约GPU显存和加速神经网络训练。

此外,硬件的发展同样也推动着模型计算的加速,随着Nvidia张量核心(Tensor Core)的普及,16bit计算也一步步走向成熟,低精度计算也是未来深度学习的一个重要趋势。

总结一下就是:省存储,省传播,省计算

1.3 使用fp16带来的问题及解决方法

参考博客:【PyTorch】唯快不破:基于Apex的混合精度加速,PyTorch的自动混合精度(AMP)

fp16 的优势是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快。

fp16 的劣势是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率。比如反向求导中很接近0的小数值梯度用fp16表示后变为0,从而导致梯度消失,训练停滞。

可见,当 fp16有优势的时候就用 fp16,而为了消除 fp16 的劣势,有两种解决方案:

  • 梯度缩放,通过放大 loss 的值来防止梯度的 underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去)。也就是将loss值放大k倍,根据链式法则,反向传播中的梯度也会放大k倍,原来不能被fp16表示的数就可以被fp16表示。
  • 由 pytorch 自动决定什么时候用fp16,什么时候用fp32, 一般用 fp16 做储存和乘法从而加速计算,用 fp32 做累加避免舍入误差。如在卷积和全连接操作中用fp16,在 Softmax 操作中用fp32 , 这是 amp 自动设定和计算的。

在神经网络处理器NPU中,在前向计算,反向求导,梯度传输时候用fp16,参数更新阶段将fp16参数加到参数的fp32副本上。下一轮迭代时,将fp32副本上的参数转为fp16,用于前向计算。二者之间的转换为NPU内部自动实现的,操作者不可见也无法干预。

Loss Scale 分为静态和动态 Loss Scale,动态 Loss Scale 会自动更改 Loss Scale 的缩放倍数。

2. apex 介绍与安装

apex 的全称是 A PyTorch Extension ,其实就是一种 pytorch 的拓展插件,其本身与混合精度并无关系。apex 是 Nvidia 开发的基于 PyTorch 的混合精度训练加速神器,因此 Apex 必须在GPU上使用,而不能在CPU中使用。

apex包的nvidia官网介绍:Mixed-Precision Training of Deep Neural Networks

amp 的全称是 auto mixed precision,自动混合精度,是一个用来支持模型训练在pytorch框架下使用混合精度进行加速训练的拓展插件之类的库。它最核心的东西在于低精度 fp16 , 它能够提供一种可靠友好的方式使得模型在 fp16 精度下进行训练。

从 apex 中引入 amp 的方法是: from apex import amp
pytorch 原生支持的 amp 的使用方法是:from torch.cuda.amp import autocast as autocast, GradScaler

apex安装过程参考博客: PyTorch apex库安装(Linux系统)

3. apex.amp 的使用

3.1 三行代码实现 amp

只需要在程序中加入这几行代码即可(引自apex文档):

from apex import amp
model, optimizer = amp.initialize(model, optimizer,opt_level="O1",loss_scale=128.0) 
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()    

amp 是 pytorch 的自动混合精度,具体介绍可参考:https://zhuanlan.zhihu.com/p/165152789

scale 是缩放的意思,通过放大loss的值来防止梯度下溢,不过这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去。

3.2 参数配置

opt_level 参数:

  • O0:纯FP32训练,可以作为accuracy的baseline
  • O1:混合精度训练,根据黑白名单自动决定使用 FP16 还是 FP32 进行计算
  • O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算
  • O3:纯FP16训练,很不稳定,但是可以作为speed的baseline

说明:

  • 推荐优先使用 opt_level=‘O2’, loss_scale=128.0 的配置进行amp.initialize
  • 若无法收敛推荐使用 opt_level=‘O1’, loss_scale=128.0 的配置进行amp.initialize
  • 若依然无法收敛推荐使用 opt_level=‘O1’, loss_scale=None 的配置进行amp.initialize

3.3 amp测试:MNIST 手写数字识别

代码:

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

############################
# edit this for amp
from apex import amp
############################


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
                    help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                    help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-model', default=True,
                    help='For Saving the current Model')
args = parser.parse_args()

device = torch.device('cuda:0')
torch.manual_seed(args.seed)
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
cuda_kwargs = {'num_workers': 1,'pin_memory': True,'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
dataset1 = datasets.MNIST('./ms', train=True, download=True,transform=transform)
dataset2 = datasets.MNIST('./ms', train=False,transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        #################################################
        # edit this for amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # loss.backward()
        #################################################

        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)


##############################################################################################3
#add this for amp
opt_level = 'O2'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level,loss_scale=128.0)
###############################################################################################


scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()
if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")
    # torch.save(model, "mnist_cnn.pt")  会报错,只能保存模型参数,不能保存模型

【注】经过 Apex 的 model 不能貌似保存模型,只能保存模型参数。因此不能用 torch.save(model, ‘model.pt’) 保存模型,只能用 torch.save(model.state_dict(), ‘model.pt’) 保存模型参数。原因不详。

4. 参考资料推荐

【PyTorch】唯快不破:基于Apex的混合精度加速

PyTorch的自动混合精度(AMP)

fp16与fp32简介与试验

pytorch原生支持的apex混合精度和nvidia apex混合精度AMP技术加速模型训练效果对比

你可能感兴趣的:(#,混合精度计算,python,pytorch,神经网络,深度学习,人工智能)