本节会介绍在pytorch中如何保存和加载模型,这样就可以使用之前训练好的模型进行预测或者继续训练了
state_dict
可以查看储存在model中的参数,weights和bias的矩阵
print("Our model: \n\n", model, '\n')
print("The state dict keys: \n\n", model.state_dict().keys())
print (model.state_dict())
保存模型最简单的方法就是用torch.save
, 例如我们可以把它存到checkpoint.pth
.
torch.save(model.state_dict(), 'checkpoint.pth')
然后可以使用torch.load
加载模型
state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())
如果要把这个state dict加载到网络中去,就要用model.load_state_dict(state_dict)
model.load_state_dict(state_dict)
上面的使用方法看起来非常方便,但是实际应用起来却有点复杂,加载state dict只适用于模型结构和checkpoints的结构完全一致,如果不一致就会报错
因此我们可以定义checkpoint的结构
checkpoint = {'input_size': 784,
'output_size': 10,
'hidden_layers': [each.out_features for each in model.hidden_layers],
'state_dict': model.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
model = fc_model.Network(checkpoint['input_size'],
checkpoint['output_size'],
checkpoint['hidden_layers'])
model.load_state_dict(checkpoint['state_dict'])
return model
model = load_checkpoint('checkpoint.pth')
print(model)
查看完整代码参考
https://github.com/udacity/deep-learning-v2-pytorch.git中
intro-to-pytorch的Part 6
本系列笔记来自Udacity课程《Intro to Deep Learning with Pytorch》
全部笔记请关注微信公众号【阿肉爱学习】,在菜单栏点击“利其器”,并选择“pytorch”查看