pytorch使用Apex混合精度加速训练

Apex官网:https://nvidia.github.io/apex/amp.html
这篇博客讲的非常好
PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速

1.安装

使用pip安装后会出错

TypeError: Class advice impossible in Python3. Use the @Implementer
class decorator instead.

解决方法:

$ pip uninstall apex
$ git clone https://www.github.com/nvidia/apex
$ cd apex
$ python setup.py install

2.使用

核心代码:

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # “欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
		scaled_loss.backward()

例子:

原始训练代码:

import torch
ngpu=2
def traiin():
		model = torch.nn.Linear(D_in, D_out).cuda()
		model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
		optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
		for img, label in dataloader:
			out = model(img.half())
			loss = LOSS(out, label)
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()

#此时采用全精度32位来训练

半精度训练:

import torch
ngpu=2
def traiin():
		model = torch.nn.Linear(D_in, D_out).cuda().half()
		model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
		optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
		for img, label in dataloader:
			out = model(img.half())
			loss = LOSS(out, label)
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()
#此时采用半精度16位来训练

显存基本可以降低为原来的一半,但训练速度降低,可能原因是,CUDNN只支持float32加速,半精度后,将不能加速

混合精度训练:

import torch
ngpu=2
def train():
		model = torch.nn.Linear(D_in, D_out).cuda()
		optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
		#设置混合精度模式为O1(欧1,不是零1,后面会解释各个模式区别)
		model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
		model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
		for img, label in dataloader:
			out = model(img)
			loss = LOSS(out, label)
			#将loss进行缩放,防止溢出
			with amp.scale_loss(loss, optimizer) as scaled_loss:
		    	scaled_loss.backward()
		
			optimizer.step()
			optimizer.zero_grad()
def save_model(self, epoch):
        if self.mixed_precision:
            import apex.amp as amp
            amp_state_dict = amp.state_dict()
        else:
            amp_state_dict = None
        checkpoint = {
     
            'epoch': epoch,
            'params': self.params,
            'model': self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'amp': amp_state_dict
        }
        torch.save(checkpoint, os.path.join(self.expdir,'model.pt'))

def load_model(self, checkpoint):
	  state_dict = torch.load(checkpoint)
	  self.model.load_state_dict(state_dict['model'])
	
	  if self.mixed_precision:
	      import apex.amp as amp
	      amp.load_state_dict(state_dict['amp'])

注意:
1.模型在amp.initialize前必须加载到GPU上。
2.amp.initialize前不能对模型进行任何分布式操作,如torch.nn.DataParallel必须放在之后。

opt_level 解释
O0 纯 FP32 训练,可以作为 accuracy 的 baseline
O1 混合精度训练(推荐使用),根据黑白名单自动决定使用 FP16(GEMM, 卷积)还是 FP32(Softmax)进行计算
O2 几乎FP16混合精度训练,不存在黑白名单,除了 Batch Norm,几乎FP16 计算
O3 纯 FP16 训练,很不稳定,但是可以作为 speed 的 baseline

参考:
PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速 [CSDN]
Apex [官网]
Apex混合精度加速 [码农网]

你可能感兴趣的:(深度学习)