pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练

模型的保存与加载

PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)

  • torch.save主要参数: obj:对象 、f:输出路径
  • torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu
    pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练_第1张图片

一、常见的模型保存的两种方法:

1、保存整个Module

torch.save(net, path)

pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练_第2张图片

2、保存模型参数

state_dict = net.state_dict()
torch.save(state_dict , path)

二、训练过程中自定义保存内容 与 断点恢复训练

#加载恢复
if RESUME:#是否恢复
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点模型文件路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler



#保存
for epoch in range(start_epoch + 1, 80):

    optimizer.zero_grad()

    optimizer.step()
    lr_schedule.step()


    if epoch %20 == 19:#每隔20个epoch保存一次模型
        print('epoch:',epoch)
        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
        #自定义要保存的参数信息
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

在这里插入图片描述
pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练_第3张图片

模型层该改变时,再次载入之前模型的权重,只需要 model.load_state_dict(torch.load(PATH), strict = False)

pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练_第4张图片
pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练_第5张图片

你可能感兴趣的:(【pytorch】,深度学习,pytorch)