【PyTorch】模型的存储和加载

1.  总体情况
说到模型的存储,主要有三个函数可以用:
(1) torch.save(): Model, Tensor和各个object的字典都会被存储起来
(2) torch.load(): 加载模型
(3) torch.nn.Model.load_state_dict():通过去序列化的state_dict来加载模型权重(Loads a model’s parameter dictionary using a deserialized state_dict
 

2. 什么是state_dict?
名字决定命运。这是一个记录state的dictionary,key是Model的层,value是这一层的可学习的参数,也就是那一层的权重和偏置项(提到这个想起前一篇文章的Tensor的重要属性,既然它是可学习的,那它的requires_grad属性一定是True)。

In PyTorch, the learnable parameters (i.e. weights and biases) of an torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) have entries in the model’s state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

例如:
 

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(in_features=16*28*28, out_features=120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Optimizer's state_dict")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

【PyTorch】模型的存储和加载_第1张图片
这是PyTorch很好的一个性质。你在建立了前向传播之后,pytorch会针对模型生成一个字典,这个字典记录了每一层的名字的和参数。
除了模型,optimizer也建立了这样的字典。
所以如果你想看模型的每个层都是什么(这个模型要继承nn.Module,并且定义好前向传播过程),就调用model.state_dict();
如果你想看优化器(优化器要通过torch.optim.SGD()或者其他优化器来指定)的参数设置,就调用optimizer.state_dict()

 

3. 真正在保存和加载模型的时候,有配对的函数:
对于仅保存state_dict()的方式,那保存和加载模型的方式为:
保存:torch.save(model.state_dict(), PATH)
加载:model.laod_state_dict(torch.load(PATH))
一般加载模型是在训练完成后用模型做测试,这时候加载模型记得要加上model.eval(),把模型切换到evaluation模式,这时候会调整dropout和bactch的模式。

对于保存和加载整个模型的情况:
torch.save(model, PATH)
model = torch.load(PATH)
可以看到,前面的model.load_state_dict()和这里的不同,前面的情况需要你先定义一个模型,然后再load_state_dict()
但是这里load整个模型,会把模型的定义一起load进来。完成了模型的定义和加载参数的两个过程。

4. 保存时候的其他问题
还可以保存checkpoint,多模型存储,GPU,CPU之间的切换
参考https://pytorch.org/tutorials/beginner/saving_loading_models.html
 

你可能感兴趣的:(【PyTorch】模型的存储和加载)