(2)Pytorch保存模型权重参数

保存模型权重参数

我们可以选择只保存模型权重参数,或者保存模型结构+权重参数,通常采用前者。此处介绍只保存模型权重的方法

1、只保存模型权重参数

# dir = 'xxxx/resnet18.pth'
import torch
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

# 保存
torch.save(resnet18.state_dict(), 'xxxx/resnet18.pth')
# 调用
resnet18 = models.resnet18() 
resnet18.load_state_dict(torch.load('xxxx/resnet18.pth'))

2、保存模型权重、优化器权重、epoch信息

dir = 'mymodel.pth'
state = {'net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, dir) # 权重参数包括了模型权重、优化器权重、epoch
checkpoint = torch.load(dir) # checkpoint 把之前save的state加载进来
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

你可能感兴趣的:(pytorch基础知识积累,pytorch,python,深度学习)