[Pytorch] 保存模型与加载模型

1、保存模型

# 定义模型
model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output) #调用网络

# 保存模型
torch.save(model, 'BPNetModel0.pth')

2、加载模型

import torch

## 读取模型
model = torch.load('BPNetModel0.pth')

3、保存模型参数 

 #调用网络
model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output)

# 保存模型
torch.save({'model': model.state_dict()}, 'BPNetModel0.pth')

 4、加载参数

# 读取模型
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])

你可能感兴趣的:(Pytorch,pytorch,人工智能,python)