[Pytorch] 模型的保存torch.save()与加载torch.load()

在模型训练验证之后,可以把训练好的模型进行保存
后续进行测试的时候,可以直接加载模型进行测试

保存模型

将模型保存下来,如果后续有数据了可以继续训练,或者直接加载进行测试。

  1. 只保存模型的参数(一般只保存参数即可)
torch.save(model.state_dict(), PATH)  # 保存模型参数
  1. 保存整个模型(较大)
torch.save(model,PATH)  # 保存整个模型

加载模型

  1. 只加载模型的参数(先创建模型、并加载参数,再恢复得到模型)
model = MyModel().to(device)
checkpoint = torch.load(config['save_path'])  # 先加载参数
model.load_state_dict(checkpoint)  # 再让模型加载参数, 恢复得到模型
  1. 加载整个模型
model = torch.load(PATH)

在torch.load()中,常常会使用map_location进行cpu和gpu的转化。
(1)【GPU->CPU】
比如模型训练的时候是在GPU上进行并保存的,测试的时候却想在CPU上进行训练:

model = torch.load(PATH, map_location='cpu')

(2)【CPU->GPU】
转为GPU需要指明在哪块GPU上,例如转到第0块GPU:

model = torch.load(PATH, map_location=lambda storage, loc: storage.cuda(0))

(3)【GPU->GPU】
不同块GPU的转换,例如第1块转到第0块:

torch.load(PATH, map_location={'cuda:1':'cuda:0'})

Note:如果只保存了模型参数,就加载模型参数;如果保存了整个模型,就加载整个模型;上述两组一一对应。

参考:

  1. pytorch 状态字典:state_dict使用详解:https://blog.csdn.net/Bruce_0712/article/details/111990905
  2. torch.load_state_dict()函数的用法总结:https://blog.csdn.net/ChaoMartin/article/details/118686268
  3. pytorch cpu与gpu load时相互转化 torch.load(map_location=):https://blog.csdn.net/bc521bc/article/details/85623515

你可能感兴趣的:(model-pytorch,pytorch,深度学习,python)