Pytorch如何保存和加载模型参数

pytorch 保存和加载模型的方法有两种:

1.保存网络的参数

import torch
#导入模块

net=Net()
#创建网络,当然还需要损失函数梯度等省略


PATH='state_dict_model.pth'
#先建立路径
torch.save(net.state_dict(),PATH)
#保存:可以是pth文件或者pt文件

model=Net()
model.load_state_dict(torch.load(PATH))
#载入保存的模型参数
model.eval()
#不启用 BatchNormalization 和 Dropout

2.保存整个网络

import torch

PATH = "entire_model.pt"
# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()

Remember too, that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

你可能感兴趣的:(python相关学习,卷积神经网络,深度学习,python,深度学习,网络)