用PyTorch加载已经保存好的模型参数文件时遇到一个bug:
Unexpected key(s) in state_dict: “module.features. …”.,Expected “.features…”
意思是从本地文件中加载模型的state_dict时,state_dict的key值不匹配。查了一些资料后,发现是PyTorch多gpu保存的问题,导致保存下来的state_dict中的key比原来都多了一个module,因此出现了上述这个问题。下面简单验证一下。
import torch
import torch.nn as nn
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.conv = nn.Conv2d(1, 3, 3, 1, 1, bias=False)
self.model = nn.Sequential(
nn.Conv2d(3, 1, 1, 1, 1, bias=False),
nn.BatchNorm2d(1),
nn.ReLU(inplace=True),
)
def forward(self, x):
self.model(self.conv(x))
先查看直接定义后的结果:
model = net()
print(model)
上述代码的结果为:
net(
(conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(model): Sequential(
(0): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(1, eps=1e-5, momentum=0.1, affine=True, tracking_running_stats=True)
(2): ReLU(inplace)
)
)
再分别查看将模型并行和转移到GPU上的结果:
先单独查看将模型转移到GPU上的结果:
if torch.cuda.is_available():
model = model.cuda(0)
print(model)
上述代码的结果和之前的结果相同,在此就不贴结果了。说明将模型转移到GPU上的操作是没有影响的。
再来看模型并行的结果:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
print(model)
model = model.cuda()
print(model)
上述代码两个print的结果相同,均为:
DataParallel(
(module):net(
(conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(model): Sequential(
(0): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(1, eps=1e-5, momentum=0.1, affine=True, tracking_running_stats=True)
(2): ReLU(inplace)
)
)
)
由上述代码可知,模型并行使原模型外被包裹了一层module,会影响到模型的保存和加载:
先保存并行的模型:
model = net()
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model).cuda()
torch.save(model.state_dict(), 'demo.pth')
若直接加载:
model2 = net()
model2.load_state_dict(torch.load('demo.pth'))
此时就会遇到文章开头说的错误。解决这个问题有两种方法:
先将模型并行,再加载权重:
model2 = net()
model2 = nn.DataParallel(model2).cuda()
model2.load_state_dict(torch.load('demo.pth'))
去掉本地权重文件key值中的module:
model2 = net()
model2.load_state_dict({k.replace('module.', ''):v for k, v in torch.load('demo.pth').items()})
model2 = nn.DataParallel(model2).cuda()
因为模型始终是要并行的,所以推荐方法一,不增加额外的代码 ,注意一下模型并行和加载权重文件的顺序即可。