pytorch保存、加载模型, 并将网络模型.pt保存为ONNIX

pytorch的模型和参数是分开的,可以分别保存或加载模型和参数。

简单说,原先的net在保存之前,要eval一下,load之后的net也要eval一下,把所有参数freeze掉。才保证两个net完全相同(输入相同tensor得到完全一致的结果),具体原因参见:pytorch模型的保存与加载注意事项

pytorch有两种模型保存方式:

一、保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net。

二、只保存神经网络的训练模型参数,save的对象是net.state_dict()。
对应两种保存模型的方式,pytorch也有两种加载模型的方式。对应第一种保存方式,加载模型时通过torch.load(‘.pth’)直接初始化新的神经网络对象;对应第二种保存方式,需要首先导入对应的网络,再通过net.load_state_dict(torch.load(‘.pth’))完成模型参数的加载。

在网络比较大的时候,第一种方法会花费较多的时间。

Pytorch两种模型保存方式
1,只保存模型参数

# 保存
torch.save(model.state_dict(), 

你可能感兴趣的:(Pytorch学习,pytorch,python)