unexpected key "module.conv1_1.weight" in state_dict

torch加载模型时出现如下错误

异常位置

save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
network.load_state_dict(torch.load(save_path))

异常信息

File "/data/Muyi/Github/EnlightenGAN/models/base_model.py", line 54, in load_network
network.load_state_dict(torch.load(save_path))
File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "module.conv1_1.weight" in state_dict'

异常原因

最终原因在此:训练时使用GPU,使用了torch.nn.DataParallel(),而此时预测没有使用GPU,即没有使用此模块导致上述异常
if len(gpu_ids) > 0:
    netG.cuda(device=gpu_ids[0])
    netG = torch.nn.DataParallel(netG, gpu_ids)

解决方案

1):加上torch.nn.DataParallel()模块,类似我的问题只需要使用GPU即可正常运行
model = torch.nn.DataParallel(model)
cudnn.benchmark = True
2):将原来字典中module.删除掉
network.load_state_dict({k.replace('module.',''):v for k,v in torch.load(save_path).items()})
更改原来代码如下,即可在CPU/GPU下都正常运行
if len(self.gpu_ids):
    network.load_state_dict(torch.load(save_path))
else:
    network.load_state_dict({k.replace('module.',''):v for k,v in torch.load(save_path).items()})

你可能感兴趣的:(Error)