pytorch的模型保存与加载(注意的点)

pytorch的模型保存与加载(注意的点)

只保存参数(省空间)与读取方式

#保存模型...的参数,pt/pkl/pth没啥区别
torch.save(model.state_dict(),'./model.pt')
#加载模型
new_model = Model()
model_dict = torch.load('./model.pt')
new_model.load_state_dict(model_dicts)

保存模型(省力气)与读取方式

#保存模型
torch.save(model,'./model.pt')
#加载模型
new_model = torch.load('./model.pt')

需要注意的点!!!

保存模型并不是真的保存了完整模型,而是保存了模型所在文件名与模型类名。读取.pt(或.pkl)文件时,会去找原模型文件,比如模型名称为:LSTM,存放模型的文件名为Mymodel.py,那么读取.pt文件时,程序会访问当前程序目录下的Mymodel.LSTM,然后形成真正的能够自动计算的网络模型。最终结果是什么?1. 你使用保存模型的方式后想传送模型到别的地方,那么只传pt/pkl文件是不行的,你得将原模型所在的文件一同传送。2.你一旦修改模型中forward的代码,那么就算你读取早就保存好pt文件做测试,也会使用新的forward函数。这也就是为啥很多人都说使用保存参数的方式比较好,保存参数不仅能节省空间,还能自然而然地帮你避开一些保存模型方式下才需要注意的点,虽然要重新构建相同模型网络,但也灵活少坑。

你可能感兴趣的:(pytorch,深度学习,人工智能)