AMP (Mixed-Precision) Training in PyTorch

(1) Import the modules

from torch.cuda.amp import autocast, GradScaler

(2) Create the AMP gradient scaler

scaler = GradScaler()

(3) Training: forward pass, loss, and backward pass

if opt['train']['enable_fp16']:
    optimizer.zero_grad()
    # only the forward pass and loss computation run under autocast
    with autocast():
        output = model(input)
        train_loss = loss(output, label)
    # backward and optimizer steps happen outside the autocast context
    scaler.scale(train_loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)            # optional: only needed if you clip gradients before stepping
    scaler.step(optimizer)                # skips the step if grads contain inf/NaN
    scaler.update()                       # adjust the scale factor for the next iteration
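The three steps above can be combined into a minimal end-to-end training loop. This is a sketch using a toy linear model and random data (the model, loss, and tensors are stand-ins, not from the original); it disables AMP and falls back to plain fp32 when CUDA is unavailable, since `GradScaler` and `autocast` become no-ops with `enabled=False`.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"  # AMP only applies on GPU; disabled flags make it a no-op

# toy model, loss, and optimizer (hypothetical stand-ins for model/loss/optimizer above)
model = nn.Linear(16, 4).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_fp16)

for step in range(3):
    input = torch.randn(8, 16, device=device)
    label = torch.randn(8, 4, device=device)
    optimizer.zero_grad()
    with autocast(enabled=use_fp16):
        output = model(input)             # forward pass in mixed precision
        train_loss = loss_fn(output, label)
    scaler.scale(train_loss).backward()   # scaled backward pass
    scaler.step(optimizer)                # unscales grads internally, then steps
    scaler.update()                       # adjusts the loss scale for the next step
```

Note that `scaler.step(optimizer)` already unscales the gradients internally, so the explicit `scaler.unscale_(optimizer)` call is only required when you need to inspect or clip the true (unscaled) gradients first. In recent PyTorch versions the same utilities also live under `torch.amp` (e.g. `torch.amp.autocast("cuda")`).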
