Apex官网:https://nvidia.github.io/apex/amp.html
这篇博客讲的非常好
PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速
使用pip安装后会出错
TypeError: Class advice impossible in Python3. Use the @Implementer
class decorator instead.
解决方法:
$ pip uninstall apex
$ git clone https://www.github.com/nvidia/apex
$ cd apex
$ python setup.py install
核心代码:
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # “欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
原始训练代码:
import torch
ngpu=2
def traiin():
model = torch.nn.Linear(D_in, D_out).cuda()
model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
out = model(img.half())
loss = LOSS(out, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
#此时采用全精度32位来训练
半精度训练:
import torch
ngpu=2
def traiin():
model = torch.nn.Linear(D_in, D_out).cuda().half()
model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
out = model(img.half())
loss = LOSS(out, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
#此时采用半精度16位来训练
显存基本可以降低为原来的一半,但训练速度降低,可能原因是,CUDNN只支持float32加速,半精度后,将不能加速
混合精度训练:
import torch
ngpu=2
def train():
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
#设置混合精度模式为O1(欧1,不是零1,后面会解释各个模式区别)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
for img, label in dataloader:
out = model(img)
loss = LOSS(out, label)
#将loss进行缩放,防止溢出
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()
def save_model(self, epoch):
if self.mixed_precision:
import apex.amp as amp
amp_state_dict = amp.state_dict()
else:
amp_state_dict = None
checkpoint = {
'epoch': epoch,
'params': self.params,
'model': self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'amp': amp_state_dict
}
torch.save(checkpoint, os.path.join(self.expdir,'model.pt'))
def load_model(self, checkpoint):
state_dict = torch.load(checkpoint)
self.model.load_state_dict(state_dict['model'])
if self.mixed_precision:
import apex.amp as amp
amp.load_state_dict(state_dict['amp'])
注意:
1.模型在amp.initialize前必须加载到GPU上。
2.amp.initialize前不能对模型进行任何分布式操作,如torch.nn.DataParallel必须放在之后。
opt_level | 解释 |
---|---|
O0 | 纯 FP32 训练,可以作为 accuracy 的 baseline |
O1 | 混合精度训练(推荐使用),根据黑白名单自动决定使用 FP16(GEMM, 卷积)还是 FP32(Softmax)进行计算 |
O2 | 几乎FP16混合精度训练,不存在黑白名单,除了 Batch Norm,几乎FP16 计算 |
O3 | 纯 FP16 训练,很不稳定,但是可以作为 speed 的 baseline |
参考:
PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速 [CSDN]
Apex [官网]
Apex混合精度加速 [码农网]