PyTorch 模型读取和存储

需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用

1)读写Tensor——save函数和load函数

使用save函数load函数分别存储和读取Tensor

2)读写模型——save函数和load_state_dict函数

state_dict

在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象

注意,只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

保存和加载模型

两种常见的方法

  1. 仅保存和加载模型参数(state_dict)

  保存:torch.save(model.state_dict(), PATH)

  加载:model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH))

  1. 保存和加载整个模型

保存:torch.save(model, PATH)

加载:model = torch.load(PATH)

你可能感兴趣的:(#,PyTorch,pytorch,深度学习,python)