PyTorch中一般约定是使用.pt或.pth文件扩展名保存模型,通过torch.save保存模型,通过torch.load加载模型。torch.save和torch.load函数的实现在torch/serialization.py文件中。
这里以LeNet5模型为例进行说明。LeNet5的介绍过程参考:https://blog.csdn.net/fengbingchun/article/details/125462001
你应该保存模型的参数,而不是模型本身(you should keep the parameters of the model, not the model itself)。保存模型进行推理(inference)时,只需要保存训练模型的学习参数即可。使用torch.save函数保存模型的state_dict将为你以后恢复模型提供最大的灵活性,这就是为什么它是保存模型的推荐方法。
torch.save函数有两种保存方式:一种是保存整个模型,此时模型的type应该为继承自nn.Module的类,这里则为类LeNet5;另一种是仅保存模型的参数,此时模型的type应该为有序字典即类OrderedDict。
torch.save函数将序列化的对象保存到磁盘。此函数使用Python的pickle进行序列化。通过pickle可以保存各种对象的模型、张量和字典。
pickle的介绍参考参考:https://blog.csdn.net/fengbingchun/article/details/125584682
torch.load函数使用pickle的unpickling将pickle对象文件反序列化到内存中。
torch.nn.Module的load_state_dict函数:使用反序列化的state_dict加载模型的参数字典。
torch.nn.Module的state_dict函数:在PyTorch中,torch.nn.Module模型的可学习参数(即weights和biases)包含在模型的参数中(通过model.parameters函数访问)。state_dict只是一个Python字典对象,它将每一层映射到其参数张量(tensor)。注意:只有具有可学习参数的层(卷积层,线性层等)和注册缓冲区(batchnorm’s running_mean)在模型的state_dict中有条目( Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict)。优化器对象(torch.optim)也有一个state_dict,其中包含有关优化器状态的信息,以及使用的超参数。因为state_dict对象是Python字典,所以它们可以很容易地保存、更新、更改和恢复。
注意:
(1).在运行推理之前,你必须调用model.eval函数将dropout和批量标准化层(batch normalization layers)设置为评估模式。不这样做会产生不一致的推理结果。
(2).load_state_dict函数采用字典对象,而不是保存对象的路径(load_state_dict function takes a dictionary object, NOT a path to a saved object)。这意味着你必须在将保存的state_dict传递给load_state_dict函数之前对其进行反序列化(you must deserialize the saved state_dict before you pass it to the load_state_dict function)。
(3).如果你只打算保留性能最佳的模型,不要忘记best_model_state = model.state_dict()返回对状态的引用而不是其副本(not its copy)。你必须序列化best_model_state或使用best_model_state = deepcopy(model.state_dict()) 否则你的best_model_state将通过后续训练迭代不断更新。结果,最终的模型状态将是过拟合模型的状态。
(4).推荐:torch.save(model.state_dict(), PATH)/model.load_state_dict(torch.load(PATH));
不推荐:torch.save(model, PATH)/model = torch.load(PATH):保存整个模型。
以下是测试的代码段:
def save_load_model(model):
'''saving and loading models'''
model.load_state_dict(torch.load("../../data/Lenet-5.pth")) # 加载模型
model.eval() # 将网络设置为评估模式
# state_dict:返回一个字典,保存着module的所有状态,参数和persistent buffers都会包含在字典中,字典的key就是参数和buffer的names
print("model state dict keys:", model.state_dict().keys())
print("model type:", type(model)) # model type:
print("model state dict type:", type(model.state_dict())) # model state dict type:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
print_state_dict(model.state_dict(), optimizer.state_dict())
torch.save(model, "../../data/Lenet-5_all.pth") # 保存整个模型
torch.save(model.state_dict(), "../../data/Lenet-5_parameters.pth") # 推荐:仅保存训练模型的参数,为以后恢复模型提供最大的灵活性
保存一般检查点(checkpoint)用于推理或恢复训练时,你保存的不仅仅是模型的state_dict,保存优化器的state_dict也很重要,因为它包含随着模型训练而更新的缓冲区和参数(buffers and parameters)。你可能还想要保存已训练的epoch编号、最新记录的训练损失、以及外部的torch.nn.Embedding层等。这样的checkpoint通常比单独的模型大2至3倍。
要保存多个组件,需要将它们组织在字典中并使用torch.save序列化字典。一个常见的PyTorch约定是使用.tar文件扩展名保存这些checkpoint。
以下是测试的代码段:
def save_load_checkpoint(model):
'''saving & loading a general checkpoint for inference and/or resuming training'''
path = "../../data/Lenet-5_parameters.tar"
model.load_state_dict(torch.load("../../data/Lenet-5.pth")) # 加载模型
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
torch.save({
'epoch': 5,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, path)
checkpoint = torch.load(path)
model2 = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2 = torch.optim.SGD(params=model2.parameters(), lr=0.1)
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
print("epoch:", epoch)
model.eval() # 将网络设置为评估模式
#model.train() # 恢复训练,将网络设置为训练模式
print_state_dict(model2.state_dict(), optimizer2.state_dict())
保存由多个torch.nn.Modules组成的模型时,例如GAN、sequence-to-sequence model或模型集合,需遵循与保存checkpoint时相同的方法。即保存每个模型的state_dict和相应优化器的dictionary。
以下是测试的代码段:
def save_load_multiple_models():
'''saving multiple models in one file'''
path1 = "../../data/Lenet-5.pth"
path2 = "../../data/Lenet-5_parameters_mul.tar"
model1 = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
model1.load_state_dict(torch.load(path1)) # 加载模型
optimizer1 = torch.optim.Adam(params=model1.parameters(), lr=0.001)
model2 = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
model2.load_state_dict(torch.load(path1)) # 加载模型
optimizer2 = torch.optim.SGD(params=model2.parameters(), lr=0.1)
torch.save({
'epoch': 100,
'model1_state_dict': model1.state_dict(),
'model2_state_dict': model2.state_dict(),
'optimizer1_state_dict': optimizer1.state_dict(),
'optimizer2_state_dict': optimizer2.state_dict(),
}, path2)
checkpoint = torch.load(path2)
modelA = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
modelA.load_state_dict(checkpoint['model1_state_dict'])
optimizerA = torch.optim.SGD(params=modelA.parameters(), lr=0.1)
optimizerA.load_state_dict(checkpoint['optimizer1_state_dict'])
modelB = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
modelB.load_state_dict(checkpoint['model2_state_dict'])
optimizerB = torch.optim.Adam(params=modelB.parameters(), lr=0.01)
optimizerB.load_state_dict(checkpoint['optimizer2_state_dict'])
epoch = checkpoint['epoch']
print("epoch:", epoch)
modelA.eval() # 将网络设置为评估模式
#modelA.train() # 恢复训练,将网络设置为训练模式
#modelB.eval() # 将网络设置为评估模式
modelB.train() # 恢复训练,将网络设置为训练模式
print_state_dict(modelA.state_dict(), optimizerA.state_dict())
print_state_dict(modelB.state_dict(), optimizerB.state_dict())
部分加载模型或加载部分模型(partially loading a model or loading a partial model)是迁移学习或训练新的复杂模型时的常见场景。利用经过训练的参数,即使只有少数可用,也将有助于热启动(warmstart)训练过程,并有望帮助你的模型从头开始训练更快地收敛。
PyTorch中其它保存模型的方法:
(1).torch.package是一种以独立、稳定的格式打包PyTorch模型的新方法。它包含模型参数、元数据(metadata)及架构。此外,torch.package添加了对创建包含任意PyTorch code的密封包(hermetic package)的支持,这意味着你可能会使用它来打包你想要的任何东西,例如PyTorch DataLoaders、Datasets等。
(2).使用经过训练的模型进行推理的另一种方法是使用TorchScript,它是PyTorch模型的中间表示,可以在Python以及C++等环境中运行。TorchScript实际上是用于扩展推理和部署的推荐模型格式。注意:使用TorchScript格式,你将能够加载导出的模型并进行推理,而无需定义模型类。
以上文字描述主要翻译自:https://pytorch.org/tutorials/beginner/saving_loading_models.html
GitHub:https://github.com/fengbingchun/PyTorch_Test