PyTorch模型保存torch.save()与加载torch.load()


Author :Horizon Max

编程技巧篇:各种操作小结

机器视觉篇:会变魔术 OpenCV

深度学习篇:简单入门 PyTorch

神经网络篇:经典网络模型

算法篇:再忙也别忘了 LeetCode


文章目录

    • 模型保存
      • 1)保存全部
      • 2)保存部分
    • 模型加载
      • 1)加载全部
      • 2)加载部分
    • 加载模型GPU和CPU转换
      • 1)GPU 转 CPU
      • 2)CPU 转 GPU
      • 3)GPU之间转换

模型保存

torch.save()

1)保存全部

保存整个模型

torch.save(model, path)

2)保存部分

只保存模型训练的参数(不包括网络结构)

torch.save(model.state_dict(), path)

模型加载

torch.load()

1)加载全部

加载整个模型

model = torch.load(path)

2)加载部分

只加载模型训练的参数(不包括网络结构)

model = Net()    # 网络
model = model.to(device)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint)

加载模型GPU和CPU转换

使用 map_location 参数对 GPUCPU 进行转化

1)GPU 转 CPU

model = torch.load(PATH, map_location='cpu')

2)CPU 转 GPU

model = torch.load(PATH, map_location=lambda storage, loc: storage.cuda(0))

3)GPU之间转换

torch.load(PATH, map_location={'cuda:0':'cuda:1'})    # GPU0转到GPU1

你可能感兴趣的:(各种操作小结,PyTorch,torch.save,torch.load,深度学习)