PyTorch-Tutorials【pytorch官方教程中英文详解】- 8 Save and Load Model

在文章PyTorch-Tutorials【pytorch官方教程中英文详解】- 7 Optimization中介绍了模型如何优化,包括构造损失函数和优化器等,接下来看看如何保存已经优化的模型,以及如何载入保存的模型。

原文链接:Save and Load the Model — PyTorch Tutorials 1.10.1+cu102 documentation

SAVE AND LOAD THE MODEL

In this section we will look at how to persist model state with saving, loading and running model predictions.

【在本节中,我们将看看如何通过保存、加载和运行模型预测来持久化模型状态。】

import torch
import torchvision.models as models

1 Saving and Loading Model Weights

PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:

【PyTorch模型将学习到的参数存储在一个名为state_dict的内部状态字典中。这些可以通过torch.save方法:】

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.

【要加载模型权重,您需要首先创建同一个模型的实例,然后使用load_state_dict()方法加载参数。】

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

NOTE

  • be sure to call model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
  • 在推断将退出和批处理规范化层设置为求值模式之前,请确保调用model.eval()方法。如果不这样做,将产生不一致的推理结果。】

2 Saving and Loading Models with Shapes

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:

【当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能希望将这个类的结构与模型一起保存,在这种情况下,我们可以将model(而不是model.state_dict())传递给保存函数:

torch.save(model, 'model.pth')

We can then load the model like this:

【然后我们可以像这样加载模型:】

model = torch.load('model.pth')

NOTE

  • This approach uses Python pickle module when serializing the model, thus it relies on the actual class definition to be available when loading the model.
  • 【这种方法在序列化模型时使用Python pickle模块,因此它依赖于在加载模型时可用的实际类定义。】

 3 Related Tutorials

Saving and Loading a General Checkpoint in PyTorch

恭喜!整个【pytorch官方教程中英文详解】系列更新完毕,一共八节。

说明:记录学习笔记,如果错误欢迎指正!写文章不易,转载请联系我。

你可能感兴趣的:(DL框架,AI,笔记,pytorch,人工智能,python)