我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don’t expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?
好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。
请看下面的一个例子。
我们先搭建一个小小的网络。
-
import torch
as t
-
from torch.nn
import Module
-
from torch
import nn
-
from torch.nn
import functional
as F
-
class Net(Module):
-
def __init__(self):
-
super(Net,self).__init__()
-
self.conv1 = nn.Conv2d(
3,
32,
3,
1)
-
self.conv2 = nn.Conv2d(
32,
3,
3,
1)
-
self.w = nn.Parameter(t.randn(
3,
10))
-
for p
in self.children():
-
nn.init.xavier_normal_(p.weight.data)
-
nn.init.constant_(p.bias.data,
0)
-
def forward(self, x):
-
out = self.conv1(x)
-
out = self.conv2(x)
-
-
out = F.avg_pool2d(out,(out.shape[
2],out.shape[
3]))
-
out = F.linear(out,weight=self.w)
-
return out
-
然后我们保存这个网络的初始值。
-
model = Net()
-
t.save(model.state_dict(),
'xxx.pth')
现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。
-
import torch
as t
-
from torch.nn
import Module
-
from torch
import nn
-
from torch.nn
import functional
as F
-
-
-
class Net(Module):
-
def __init__(self):
-
super(Net, self).__init__()
-
self.conv1 = nn.Conv2d(
3,
32,
3,
1)
-
self.conv2 = nn.Conv2d(
32,
3,
3,
1)
-
self.conv3 = nn.Conv2d(
3,
64,
3,
1)
-
self.conv4 = nn.Conv2d(
64,
32,
3,
1)
-
for p
in self.children():
-
nn.init.xavier_normal_(p.weight.data)
-
nn.init.constant_(p.bias.data,
0)
-
-
self.w = nn.Parameter(t.randn(
3,
10))
-
def forward(self, x):
-
out = self.conv1(x)
-
out = self.conv2(x)
-
-
out = F.avg_pool2d(out, (out.shape[
2], out.shape[
3]))
-
out = F.linear(out, weight=self.w)
-
return out
我们现在试着导入之前保存的模型参数。
-
path =
'xxx.pth'
-
model = Net()
-
model.load_state_dict(t.load(path))
-
-
'''
-
RuntimeError: Error(s) in loading state_dict for Net:
-
Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias".
-
'''
出现了没有在模型文件中找到error中的关键字的错误。
现在我们这样导入模型
-
path =
'xxx.pth'
-
model = Net()
-
save_model = t.load(path)
-
model_dict = model.state_dict()
-
state_dict = {k:v
for k,v
in save_model.items()
if k
in model_dict.keys()}
-
print(state_dict.keys())
# dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
-
model_dict.update(state_dict)
-
model.load_state_dict(model_dict)
看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。
为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。
-
for k
in state_dict.keys():
-
print(k)
-
-
'''
-
w
-
conv1.weight
-
conv1.bias
-
conv2.weight
-
conv2.bias
-
-
'''
-
for k
in model_dict.keys():
-
print(k)
-
-
'''
-
w
-
conv1.weight
-
conv1.bias
-
conv2.weight
-
conv2.bias
-
conv3.weight
-
conv3.bias
-
conv4.weight
-
conv4.bias
-
'''
这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。