【Pytorch】14. 保存和加载模型

本节会介绍在pytorch中如何保存和加载模型,这样就可以使用之前训练好的模型进行预测或者继续训练了

state_dict可以查看储存在model中的参数,weights和bias的矩阵

print("Our model: \n\n", model, '\n')
print("The state dict keys: \n\n", model.state_dict().keys())
print (model.state_dict())

保存模型最简单的方法就是用torch.save, 例如我们可以把它存到checkpoint.pth.

torch.save(model.state_dict(), 'checkpoint.pth')

然后可以使用torch.load加载模型

state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())

如果要把这个state dict加载到网络中去,就要用model.load_state_dict(state_dict)

model.load_state_dict(state_dict)

上面的使用方法看起来非常方便,但是实际应用起来却有点复杂,加载state dict只适用于模型结构和checkpoints的结构完全一致,如果不一致就会报错

因此我们可以定义checkpoint的结构

checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    
    return model
    
model = load_checkpoint('checkpoint.pth')
print(model)

查看完整代码参考
https://github.com/udacity/deep-learning-v2-pytorch.git中
intro-to-pytorch的Part 6

本系列笔记来自Udacity课程《Intro to Deep Learning with Pytorch》

全部笔记请关注微信公众号【阿肉爱学习】,在菜单栏点击“利其器”,并选择“pytorch”查看

【Pytorch】14. 保存和加载模型_第1张图片

你可能感兴趣的:(pytorch)