Pytorch-模型的保存于加载

简介

Pytorch中的序列化和反序列化

  1. troch.save
    主要参数
  • obj:对象
  • f:输出路径
  1. torch.load
    主要参数
  • f:文件路径
  • map_location:指定存放位置,cpu或者gpu

对于保存有两种方法:
1.保存整个Moucle, torch.save(net,path)
2.保存模型的参数:
state_dictt=net.state_dict()
torch.save(state_dict,path)

#方式1加载模型
path_model='./model.pkl'
net_load=torch.load(path_model)

#方式2加载模型
path_state_dict="./model_state_dict.pkl"
sate_dict_load=torch.load(path_state_dict)

net.load_dict(state_dict_load)

断点续训练-checkpoint

需要保存那些信息?
Pytorch-模型的保存于加载_第1张图片
只有模型和优化器的参数需要保存,此外还需要记录epoch
Pytorch-模型的保存于加载_第2张图片

checkpoint_interval = 5
#中间省略了若干训练的代码
#保存check_point
if (epoch+1) % checkpoint_interval == 0:                         
                                                                 
    checkpoint = {"model_state_dict": net.state_dict(),          
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}                                
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)  
    torch.save(checkpoint, path_checkpoint)   

#加载check_point
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch                   

你可能感兴趣的:(Pytorch学习笔记)