Pytorch 模型加载和保存

最近在使用pytorch编写深度学习代码,记录一些开发过程中遇到的问题和解决办法。

这一篇主要介绍pytorch模型保存和加载过程中遇到的问题和解决办法。

注意:

  • 在进行预测之前,必须调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。
  • load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)
  • TheModelClass为需要使用的模型。可以通过import包含该模型定义的.py文件或者在当前文件中重新定义模型。参数(*args, **kwargs)需要根据实际情况来初始化,这里只是例子。

        保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。即需要在使用的地方import原文件或者重新定义model类。

1.加载/保存整个模型

保存:

torch.save(model, PATH)

加载:

model = torch.load(PATH)
model.eval()

2.加载/保存状态字典

保存的代码:

torch.save(model.state_dict(), PATH)

加载的代码:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

你可能感兴趣的:(深度学习,pytorch,python,深度学习)