PyTorch checkpoint training: resuming training from a specified epoch

1. Saving the model

Save the entire model


torch.save(net, path)

Save only the weights (the state dict)


state_dict = net.state_dict()
torch.save(state_dict, path)
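
Loading each artifact back works differently. A minimal sketch, assuming path points to the files saved above and Net is a placeholder for your own model class:

import torch

# Load the whole pickled model (the original class definition must still be importable)
net = torch.load(path)

# Load only the weights into a freshly constructed model instance
net = Net()  # Net is a placeholder for your own model class
net.load_state_dict(torch.load(path))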

2. Saving state during training


checkpoint = {
    "net": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": epoch
}
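
The dict itself still has to be written to disk; a minimal sketch of saving it at the end of each epoch (the directory and filename pattern are assumptions, chosen to match the path used in the next step):

import os
import torch

os.makedirs("./models/checkpoint", exist_ok=True)  # make sure the checkpoint directory exists
torch.save(checkpoint, "./models/checkpoint/ckpt_best_%s.pth" % (str(epoch)))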

3. Resuming from a specified epoch

path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
checkpoint = torch.load(path_checkpoint)  # load the checkpoint

model.load_state_dict(checkpoint['net'])  # restore the model's learnable parameters
optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
start_epoch = checkpoint['epoch']  # epoch to resume from
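
If the checkpoint was saved on a different device (e.g. a GPU) than the one you resume on, a map_location argument can be passed to the torch.load call above; a small optional variant:

checkpoint = torch.load(path_checkpoint, map_location="cpu")  # remap tensors to CPU (or pass a torch.device)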

4. Complete workflow

start_epoch = -1

if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint

    model.load_state_dict(checkpoint['net'])  # restore the model's learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
    start_epoch = checkpoint['epoch']  # continue from this epoch
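
Training then continues from the epoch after the one stored in the checkpoint. A minimal sketch of the surrounding loop, continuing the block above, where EPOCHS and train_one_epoch are placeholders for your own total epoch count and per-epoch training code:

for epoch in range(start_epoch + 1, EPOCHS):   # EPOCHS is a placeholder for the total number of epochs
    train_one_epoch(model, optimizer, epoch)   # placeholder for your per-epoch training logic

    # save a new checkpoint at the end of every epoch
    checkpoint = {
        "net": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch
    }
    torch.save(checkpoint, "./models/checkpoint/ckpt_best_%s.pth" % (str(epoch)))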
