pytorch模型加载和保存

1.只保存模型参数

保存模型参数

torch.save(net.state_dict(), 'net_parameter.pkl')

加载模型参数

#定义模型结构
model = create_net()
#加载模型参数
model.load_state_dict(torch.load('net_parameter.pkl'))

2.保存完整模型

即保存模型结构又保存模型参数

torch.save(net, 'net_model.pkl')

加载模型:

net_loaded = torch.load('net_model.pkl')

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能)