PyTorch模型保存加载

PyTorch保存模型的语句是这样的:

#将模型参数保存到path路径下
torch.save(model.state_dict(), path)

加载是这样的:

model.load_state_dict(torch.load(path))

下面我们将其拆开逐句介绍

1.torch.save()和torch.load()

save函数是PyTorch的存储函数,load函数则是读取函数。save函数可以将各种对象保存至磁盘,包括张量,列表,ndarray,字典,模型等;而相应地,load函数将保存在磁盘上的对象读取出来。

用法:

torch.save(保存对象, 保存路径)
torch.load(文件路径)

应用举例:

保存张量

In [3]: a = torch.ones(3)                                                       

In [4]: a                                                                       
Out[4]: tensor([1., 1., 1.])

In [5]: torch.save(a, './a.pth')          # 保存Tensor               

In [6]: a_load = torch.load('./a.pth')    # 读取Tensor

In [7]: a_load                                                                  
Out[7]: tensor([1., 1., 1.])

保存字典

In [11]: b = {k:v for v,k in enumerate('abc',1)}                                

In [12]: b                                                                      
Out[12]: {'a': 1, 'b': 2, 'c': 3}

In [13]: torch.save(b, './b.rar')                        

In [14]: torch.load('./b.rar')                           
Out[14]: {'a': 1, 'b': 2, 'c': 3}

可以看出,保存和读取非常方便。这里需要注意的是文件的命名,命名必须要有扩展名,扩展名可以为‘xxx.pt’,‘xxx.pth’,‘xxx.pkl’,‘xxx.rar’,'xxx.tar'等形式。

2.model.state_dict()

在PyTorch中,state_dict是一个从参数名称映射到参数Tesnor的字典对象

In [15]: class MLP(nn.Module): 
    ...:     def __init__(self): 
    ...:         super(MLP, self).__init__() 
    ...:         self.hidden = nn.Linear(3, 2) 
    ...:         self.act = nn.ReLU() 
    ...:         self.output = nn.Linear(2, 1) 
    ...:  
    ...:     def forward(self, x): 
    ...:         a = self.act(self.hidden(x)) 
    ...:         return self.output(a) 
    ...:                                                                        

In [16]: net = MLP()                                                            

In [17]: net.state_dict()                                                       
Out[17]: 
OrderedDict([('hidden.weight', tensor([[ 0.4839,  0.0254,  0.5642],
                      [-0.5596,  0.2602, -0.5235]])),
             ('hidden.bias', tensor([-0.4986, -0.5426])),
             ('output.weight', tensor([[0.0967, 0.4980]])),
             ('output.bias', tensor([-0.4520]))])

可以看出,state_dict()返回的是一个有序字典,该字典的键即为模型定义中有可学习参数的层的名称+weight或+bias,值则对应相应的权重或偏差,无参数的层则不在其中。

state_dict()返回模型可学习参数的键值对,那么问题来了:model.parameters()不也是模型可学习参数的访问方式吗?它们有什么不同?

我们再来对比了一下常用的model.parameters(),在训练时,我们常常将此语句放在优化器中,表示要优化学习的模型参数。

In [18]: net.parameters()                                                       
Out[18]: 

In [19]: list(net.parameters())                                                 
Out[19]: 
[Parameter containing:
 tensor([[ 0.4839,  0.0254,  0.5642],
         [-0.5596,  0.2602, -0.5235]], requires_grad=True),
 Parameter containing:
 tensor([-0.4986, -0.5426], requires_grad=True),
 Parameter containing:
 tensor([[0.0967, 0.4980]], requires_grad=True),
 Parameter containing:
 tensor([-0.4520], requires_grad=True)]
#######################################################################
In [20]: net.named_parameters()                                                 
Out[20]: 

In [21]: list(net.named_parameters())                                           
Out[21]: 
[('hidden.weight', Parameter containing:
  tensor([[ 0.4839,  0.0254,  0.5642],
          [-0.5596,  0.2602, -0.5235]], requires_grad=True)),
 ('hidden.bias', Parameter containing:
  tensor([-0.4986, -0.5426], requires_grad=True)),
 ('output.weight', Parameter containing:
  tensor([[0.0967, 0.4980]], requires_grad=True)),
 ('output.bias', Parameter containing:
  tensor([-0.4520], requires_grad=True))]

可以看出,model.parameters()是一个生成器,每个参数张量都是一个参数容器,它的对象是各个参数Tensor。这里简单的提一下model.named_parameters(),它和model.parameters很像,都返回一个可迭代对象,但从上例可以看出它多返回一个参数名称,这样有利于访问和初始化或修改参数。
我们在用优化器优化参数时,优化对象是纯参数,所以用model.parameters():

In [22]: optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  

In [23]: optimizer.state_dict()                                                 
Out[23]: 
{'state': {},
 'param_groups': [{'lr': 0.001,
   'momentum': 0.9,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [139638712861752,
    139640593827520,
    139640585217080,
    139640585217296]}]}

除了模型中可学习参数的层(卷积层、线性层等)有state_dict,优化器也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

3.model.load_state_dict()

这是模型加载state_dict的语句,也就是说,它的输入是一个state_dict,也就是一个字典。模型定义好并且实例化后会自动进行初始化,上面的例子中我们定义的模型MLP在实例化以后显示的模型参数都是自动初始化后的随机数。
在训练模型或者迁移学习中我们会使用已经训练好的参数来加速训练过程,这时候就用load_state_dict()语句加载训练好的参数并将其覆盖在初始化参数上,也就是说执行过此语句后,加载的参数将代替原有的模型参数。

既然加载的是一个字典,那么需要注意的就是字典的键一定要相同才能进行覆盖,比如加载的字典中的'hidden.weight'只能覆盖当前模型的'hidden.weight',如果键不同,则不能实现有效覆盖操作。键相同而值的shape不同,则会将新的键值对覆盖原来的键值对,这样在训练时会报错。所以我们在加载前一般会进行数据筛选,筛选是对字典的键进行对比来操作的:

pretrained_dict = torch.load(log_dir)  # 加载参数字典
model_state_dict = model.state_dict()  # 加载模型当前状态字典
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if k in model_state_dict}  # 过滤出模型当前状态字典中没有的键值对
model_state_dict.update(pretrained_dict_1)  # 用筛选出的参数键值对更新model_state_dict变量
model.load_state_dict(model_state_dict)  # 将筛选出的参数键值对加载到模型当前状态字典中

以上代码简单的对预训练参数进行了过滤和筛选,主要是通过第3条语句粗略的过滤了键值对信息,进行筛选后要用Python更新字典的方法update()来对模型当前字典进行更新,update()方法将pretrained_dict_1中的键值对添加到model_state_dict中,若pretrained_dict_1中的键名和model_state_dict中的键名相同,则覆盖之;若不同,则作为新增键值对添加到model_state_dict中。显然,这里需要的是将pretrained_dict_1中的键值对覆盖model_state_dict的相应键值对,所以对应的键的名称必须相同,所以第3条语句中按键名称进行筛选,过滤出当前模型字典中没有的键值对。否则会报错。

如果想要细粒度过滤或更改某些参数的维度,如进行卷积核参数维度的调整,假如预训练参数里conv1有256个卷积核,而当前模型只需要200个卷积核,那么可以采用类似以下语句直接对字典进行更改:

pretrained_dict['conv1.weight'] = pretrained_dict['conv1.weight'][:200,:,:,:]   # 假设保留前200个卷积核

你可能感兴趣的:(pytorch)