在文章PyTorch-Tutorials【pytorch官方教程中英文详解】- 7 Optimization中介绍了模型如何优化,包括构造损失函数和优化器等,接下来看看如何保存已经优化的模型,以及如何载入保存的模型。
原文链接:Save and Load the Model — PyTorch Tutorials 1.10.1+cu102 documentation
SAVE AND LOAD THE MODEL
In this section we will look at how to persist model state with saving, loading and running model predictions.
【在本节中,我们将看看如何通过保存、加载和运行模型预测来持久化模型状态。】
import torch
import torchvision.models as models
PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
【PyTorch模型将学习到的参数存储在一个名为state_dict的内部状态字典中。这些可以通过torch.save方法:】
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.
【要加载模型权重,您需要首先创建同一个模型的实例,然后使用load_state_dict()方法加载参数。】
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
NOTE
When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:
【当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能希望将这个类的结构与模型一起保存,在这种情况下,我们可以将model(而不是model.state_dict())传递给保存函数:】
torch.save(model, 'model.pth')
We can then load the model like this:
【然后我们可以像这样加载模型:】
model = torch.load('model.pth')
NOTE
Saving and Loading a General Checkpoint in PyTorch
恭喜!整个【pytorch官方教程中英文详解】系列更新完毕,一共八节。
说明:记录学习笔记,如果错误欢迎指正!写文章不易,转载请联系我。