Pytorch:模型的保存与加载

Pytorch模型保存与加载,并在加载的模型基础上继续训练

  • 1.基本语句
        • 1.1 保存参数
        • 1.2 加载参数
  • 2. 语句分析
        • 2.1 torch.save()和torch.load()
        • 2.2 model.state_dict()
        • 2.3 model.load_state_dict()
  • 3. state_dict()和model.parameters()

Pytorch保存模型保存的是模型参数

1.基本语句

1.1 保存参数

一般地,采用一条语句即可保存参数:

torch.save(model.state_dict(), path)

其中model指定义的模型实例变量, path是保存参数的路径,如 path=’./model.pth’ , path=’./model.tar’, path=’./model.pkl’, 保存参数的文件一定要有后缀扩展名。
特别地,如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:

state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)

保存模型的时机可以是不同的epoch(注意路径需要不同),也可以是识别率最大的当前模型

1.2 加载参数

第一种情况,也只需要一句即可加载模型:

model.load_state_dict(torch.load(path))

第二种以字典形式保存的方法,加载方式如下:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])

需要注意的是,只保存参数的方法在加载的时候要事先定义好跟原模型一致的模型

2. 语句分析

2.1 torch.save()和torch.load()

顾名思义,save函数是Pytorch的存储函数,load函数则是读取函数。save函数可以将各种对象保存至磁盘,包括张量,列表,ndarray,字典,模型等;而相应地,load函数将保存在磁盘上的对象读取出来。

用法:

torch.save(保存对象, 保存路径)
torch.load(文件路径)

需要注意的是文件的命名,命名必须要有扩展名,扩展名可以为‘xxx.pt’,‘xxx.pth’,‘xxx.pkl’,‘xxx.rar’等形式。

2.2 model.state_dict()

在Pytorch中,state_dict是一个从参数名称隐射到参数Tesnor的字典对象。

class MLP(nn.Module): 
    def __init__(self): 
    	super(MLP, self).__init__() 
	    self.hidden = nn.Linear(3, 2) 
    	self.act = nn.ReLU() 
    	self.output = nn.Linear(2, 1) 
    
   	def forward(self, x): 
    	a = self.act(self.hidden(x)) 
    	return self.output(a)   
    	
net = MLP()
print(net.state_dict())    
# 输出如下                                                
OrderedDict([('hidden.weight', tensor([[ 0.4839,  0.0254,  0.5642],
                      [-0.5596,  0.2602, -0.5235]])),
             ('hidden.bias', tensor([-0.4986, -0.5426])),
             ('output.weight', tensor([[0.0967, 0.4980]])),
             ('output.bias', tensor([-0.4520]))])

可以看出,state_dict()返回的是一个有序字典,该字典的键即为模型定义中有可学习参数的层的名称+weight或+bias,值则对应相应的权重或偏差,无参数的层则不在其中。

2.3 model.load_state_dict()

这是模型加载state_dict的语句,也就是说,它的输入是一个state_dict,也就是一个字典。模型定义好并且实例化后会自动进行初始化,上面的例子中我们定义的模型MLP在实例化以后显示的模型参数都是自动初始化后的随机数。
在训练模型或者迁移学习中我们会使用已经训练好的参数来加速训练过程,这时候就用load_state_dict()语句加载训练好的参数并将其覆盖在初始化参数上,也就是说执行过此语句后,加载的参数将代替原有的模型参数。
既然加载的是一个字典,那么需要注意的就是字典的键一定要相同才能进行覆盖,比如加载的字典中的’hidden.weight’只能覆盖当前模型的’hidden.weight’,如果键不同,则不能实现有效覆盖操作。键相同而值的shape不同,则会将新的键值对覆盖原来的键值对,这样在训练时会报错。所以我们在加载前一般会进行数据筛选,筛选是对字典的键进行对比来操作的:

pretrained_dict = torch.load(log_dir)  # 加载参数字典
model_state_dict = model.state_dict()  # 加载模型当前状态字典
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if k in model_state_dict}  # 过滤出模型当前状态字典中没有的键值对
model_state_dict.update(pretrained_dict_1)  # 用筛选出的参数键值对更新model_state_dict变量
model.load_state_dict(model_state_dict)  # 将筛选出的参数键值对加载到模型当前状态字典中

以上代码简单的对预训练参数进行了过滤和筛选,主要是通过第3条语句粗略的过滤了键值对信息,进行筛选后要用Python更新字典的方法update()来对模型当前字典进行更新,update()方法将pretrained_dict_1中的键值对添加到model_state_dict中,若pretrained_dict_1中的键名和model_state_dict中的键名相同,则覆盖之;若不同,则作为新增键值对添加到model_state_dict中。显然,这里需要的是将pretrained_dict_1中的键值对覆盖model_state_dict的相应键值对,所以对应的键的名称必须相同,所以第3条语句中按键名称进行筛选,过滤出当前模型字典中没有的键值对。否则会报错。
如果想要细粒度过滤或更改某些参数的维度,如进行卷积核参数维度的调整,假如预训练参数里conv1有256个卷积核,而当前模型只需要200个卷积核,那么可以采用类似以下语句直接对字典进行更改:

pretrained_dict['conv1.weight'] = pretrained_dict['conv1.weight'][:200,:,:,:]   # 假设保留前200个卷积核

3. state_dict()和model.parameters()

state_dict()返回模型可学习参数的键值对,那么问题来了:model.parameters()不也是模型可学习参数的访问方式吗?它们有什么不同?
我们再来对比了一下常用的model.parameters(),在训练时,我们常常将此语句放在优化器中,表示要优化学习的模型参数。关于model.parameters(), model.named_parameters(), model.children(), model.named_children(), model.modules(), model.named_modules()请看我的这篇博客

In [18]: net.parameters()                                                       
Out[18]: <generator object Module.parameters at 0x7f009b78acf0>

In [19]: list(net.parameters())                                                 
Out[19]: 
[Parameter containing:
 tensor([[ 0.4839,  0.0254,  0.5642],
         [-0.5596,  0.2602, -0.5235]], requires_grad=True),
 Parameter containing:
 tensor([-0.4986, -0.5426], requires_grad=True),
 Parameter containing:
 tensor([[0.0967, 0.4980]], requires_grad=True),
 Parameter containing:
 tensor([-0.4520], requires_grad=True)]
#######################################################################
In [20]: net.named_parameters()                                                 
Out[20]: <generator object Module.named_parameters at 0x7f00a009a390>

In [21]: list(net.named_parameters())                                           
Out[21]: 
[('hidden.weight', Parameter containing:
  tensor([[ 0.4839,  0.0254,  0.5642],
          [-0.5596,  0.2602, -0.5235]], requires_grad=True)),
 ('hidden.bias', Parameter containing:
  tensor([-0.4986, -0.5426], requires_grad=True)),
 ('output.weight', Parameter containing:
  tensor([[0.0967, 0.4980]], requires_grad=True)),
 ('output.bias', Parameter containing:
  tensor([-0.4520], requires_grad=True))]

可以看出,model.parameters()是一个生成器,每个参数张量都是一个参数容器,它的对象是各个参数Tensor。这里简单的提一下model.named_parameters(),它和model.parameters很像,都返回一个可迭代对象,但从上例可以看出它多返回一个参数名称,这样有利于访问和初始化或修改参数
我们在用优化器优化参数时,优化对象是纯参数,所以用model.parameters():

In [22]: optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  

In [23]: optimizer.state_dict()                                                 
Out[23]: 
{'state': {},
 'param_groups': [{'lr': 0.001,
   'momentum': 0.9,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [139638712861752,
    139640593827520,
    139640585217080,
    139640585217296]}]}

除了模型中可学习参数的层(卷积层、线性层等)有state_dict,优化器也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

你可能感兴趣的:(Pytorch)