PyTorch 保存与加载模型

保存模型与加载模型

常用的两种保存与加载模型方式

# 1.保存整个网络
torch.save(net, PATH)

#针对上面保存方法,加载的方法是:
model_dict=torch.load(PATH)

# 如果有多块GPU,训练和测试使用的不是同一块GPU,则加载的方法是
model_dict=torch.load(PATH, map_location = {'cuda:3', 'cuda:0'})

# 上面'cuda:3'是训练时使用的GPU编号,'cuda:0'是测试时使用的GPU编号
# 2.保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH)

#针对上面的保存方法,加载的方法是:
model_dict=model.load_state_dict(torch.load(PATH))

你可能感兴趣的:(python,深度学习)