pytorch中模型的保存与加载:torch.save(),torch.load()

pytorch保存模型与加载:

模型的保存

torch.save(net,PATH)#保存模型的整个网络,包括网络的整个结构和参数
torch.save(net.state_dict,PATH)#只保存网络中的参数

模型的加载

分别对应上边的加载方法。

model_dict=torch.load(PATH)
model_dict=net.load_state_dict(torch.load(PATH))

你可能感兴趣的:(python,与,pytorch,pytorch,python)