Pytorch中保存模型,并在测试集上训练前加载模型。

保存模型的参数

我们在之前定义模型名为"model", 这行代码就是在保存模型目前的参数为"parameters.pkl"文件。这里的"parameters.pkl"可以改为你需要的路径。


torch.save(model.state_dict(), 'parameters.pkl')


注意此处,模型被存为了state_dict形式,state_dict其实就是一个词典,将模型的每一层与其对应的参数张量保存在了一起。

加载模型

加载模型也非常简单,只需要:


model.load_state_dict(torch.load('parameters.pkl'))
model.eval() #开启evaluation模式


之后用现在的模型来做预测就可以了。

Reference

https://pytorch.org/tutorials/beginner/saving_loading_models.html

你可能感兴趣的:(Pytorch实战)