PyTorch "RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict"

After training, we want to test the model on held-out samples, loading the saved parameters into a freshly built model with the program below.
Saving the whole model object with torch.save(model, ...) produces a comparatively large file, so it is more efficient to save only the parameters (the state_dict) and load them into a rebuilt model at test time. The rebuilt model must have exactly the same architecture as the one that was saved.

torch.save(model.state_dict(), 'hscnn_5layer_dim10_276.pkl')
# rather than torch.save(model, 'hscnn_5layer_dim10_276.pkl'), which saves the whole model object
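As a minimal sketch of the two styles, using a toy nn.Sequential as a stand-in for the real network (the post's actual model is resblock(conv_relu_res_relu_block, 16, 3, 31)):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in for the real network.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))

path = os.path.join(tempfile.mkdtemp(), 'weights.pkl')

# Save only the parameters (the state_dict), not the whole model object.
torch.save(model.state_dict(), path)

# At test time, rebuild the identical architecture, then load the parameters.
model2 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))
model2.load_state_dict(torch.load(path))
model2.eval()

print(torch.equal(model[0].weight, model2[0].weight))  # True
```

Because the second model is built with the same architecture, its state_dict keys line up with the saved ones and the load succeeds.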

Loading the saved parameters in the test script:

model_path = './models/hscnn_5layer_dim10_276.pkl'
img_path = './test_imgs/'
result_path = './test_results1/'
var_name = 'rad'

save_point = torch.load(model_path)
model_param = save_point['state_dict']  # the checkpoint stores the parameters under 'state_dict'
print(model_param.keys())
model = resblock(conv_relu_res_relu_block, 16, 3, 31)
model = nn.DataParallel(model)
model.load_state_dict(model_param)

model = model.cuda()
model.eval()

Running the program above raises an error at the line model.load_state_dict(model_param):

 File "E:\install\Anaconda3\envs\pytorch_GPU\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.input_conv.weight", "module.input_conv.bias", "module.conv_seq.0.conv1.weight", "module.conv_seq.0.conv1.bias", "module.conv_seq.0.conv2.weight", "module.conv_seq.0.conv2.bias", "module.conv_seq.1.conv1.weight", "module.conv_seq.1.conv1.bias", "module.conv_seq.1.conv2.weight", "module.conv_seq.1.conv2.bias", "module.conv_seq.2.conv1.weight", "module.conv_seq.2.conv1.bias", "module.conv_seq.2.conv2.weight", "module.conv_seq.2.conv2.bias", "module.conv_seq.3.conv1.weight", "module.conv_seq.3.conv1.bias", "module.conv_seq.3.conv2.weight", "module.conv_seq.3.conv2.bias", "module.conv_seq.4.conv1.weight", "module.conv_seq.4.conv1.bias", "module.conv_seq.4.conv2.weight", "module.conv_seq.4.conv2.bias", "module.conv_seq.5.conv1.weight", "module.conv_seq.5.conv1.bias", "module.conv_seq.5.conv2.weight", "module.conv_seq.5.conv2.bias", "module.conv_seq.6.conv1.weight", "module.conv_seq.6.conv1.bias", "module.conv_seq.6.conv2.weight", "module.conv_seq.6.conv2.bias", "module.conv_seq.7.conv1.weight", "module.conv_seq.7.conv1.bias", "module.conv_seq.7.conv2.weight", "module.conv_seq.7.conv2.bias", "module.conv_seq.8.conv1.weight", "module.conv_seq.8.conv1.bias", "module.conv_seq.8.conv2.weight", "module.conv_seq.8.conv2.bias", "module.conv_seq.9.conv1.weight", "module.conv_seq.9.conv1.bias", "module.conv_seq.9.conv2.weight", "module.conv_seq.9.conv2.bias", "module.conv_seq.10.conv1.weight", "module.conv_seq.10.conv1.bias", "module.conv_seq.10.conv2.weight", "module.conv_seq.10.conv2.bias", "module.conv_seq.11.conv1.weight", "module.conv_seq.11.conv1.bias", "module.conv_seq.11.conv2.weight", "module.conv_seq.11.conv2.bias", "module.conv_seq.12.conv1.weight", "module.conv_seq.12.conv1.bias", "module.conv_seq.12.conv2.weight", "module.conv_seq.12.conv2.bias", "module.conv_seq.13.conv1.weight", "module.conv_seq.13.conv1.bias", "module.conv_seq.13.conv2.weight", "module.conv_seq.13.conv2.bias", "module.conv_seq.14.conv1.weight", 
"module.conv_seq.14.conv1.bias", "module.conv_seq.14.conv2.weight", "module.conv_seq.14.conv2.bias", "module.conv_seq.15.conv1.weight", "module.conv_seq.15.conv1.bias", "module.conv_seq.15.conv2.weight", "module.conv_seq.15.conv2.bias", "module.conv.weight", "module.conv.bias", "module.output_conv.weight", "module.output_conv.bias". 
	Unexpected key(s) in state_dict: "input_conv.weight", "input_conv.bias", "conv_seq.0.conv1.weight", "conv_seq.0.conv1.bias", "conv_seq.0.conv2.weight", "conv_seq.0.conv2.bias", "conv_seq.1.conv1.weight", "conv_seq.1.conv1.bias", "conv_seq.1.conv2.weight", "conv_seq.1.conv2.bias", "conv_seq.2.conv1.weight", "conv_seq.2.conv1.bias", "conv_seq.2.conv2.weight", "conv_seq.2.conv2.bias", "conv_seq.3.conv1.weight", "conv_seq.3.conv1.bias", "conv_seq.3.conv2.weight", "conv_seq.3.conv2.bias", "conv_seq.4.conv1.weight", "conv_seq.4.conv1.bias", "conv_seq.4.conv2.weight", "conv_seq.4.conv2.bias", "conv_seq.5.conv1.weight", "conv_seq.5.conv1.bias", "conv_seq.5.conv2.weight", "conv_seq.5.conv2.bias", "conv_seq.6.conv1.weight", "conv_seq.6.conv1.bias", "conv_seq.6.conv2.weight", "conv_seq.6.conv2.bias", "conv_seq.7.conv1.weight", "conv_seq.7.conv1.bias", "conv_seq.7.conv2.weight", "conv_seq.7.conv2.bias", "conv_seq.8.conv1.weight", "conv_seq.8.conv1.bias", "conv_seq.8.conv2.weight", "conv_seq.8.conv2.bias", "conv_seq.9.conv1.weight", "conv_seq.9.conv1.bias", "conv_seq.9.conv2.weight", "conv_seq.9.conv2.bias", "conv_seq.10.conv1.weight", "conv_seq.10.conv1.bias", "conv_seq.10.conv2.weight", "conv_seq.10.conv2.bias", "conv_seq.11.conv1.weight", "conv_seq.11.conv1.bias", "conv_seq.11.conv2.weight", "conv_seq.11.conv2.bias", "conv_seq.12.conv1.weight", "conv_seq.12.conv1.bias", "conv_seq.12.conv2.weight", "conv_seq.12.conv2.bias", "conv_seq.13.conv1.weight", "conv_seq.13.conv1.bias", "conv_seq.13.conv2.weight", "conv_seq.13.conv2.bias", "conv_seq.14.conv1.weight", "conv_seq.14.conv1.bias", "conv_seq.14.conv2.weight", "conv_seq.14.conv2.bias", "conv_seq.15.conv1.weight", "conv_seq.15.conv1.bias", "conv_seq.15.conv2.weight", "conv_seq.15.conv2.bias", "conv.weight", "conv.bias", "output_conv.weight", "output_conv.bias". 

The cause is a mismatch between the keys in the saved state_dict and the keys the model expects. Opening the load_state_dict() function shows:

    def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
                        strict: bool = True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.
When strict is True, the keys in the state_dict must match the model's keys exactly, otherwise the error above is raised. Passing strict=False makes the error go away (mismatched keys are then simply skipped rather than loaded):

model.load_state_dict(model_param, strict=False)
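Note that strict=False does not remap anything; it only suppresses the exception. load_state_dict returns the lists of missing and unexpected keys, which is worth inspecting to confirm what was actually loaded. A small sketch, with nn.Linear standing in for the real network:

```python
import torch.nn as nn

model = nn.DataParallel(nn.Linear(4, 2))  # expects keys 'module.weight', 'module.bias'
plain_sd = nn.Linear(4, 2).state_dict()   # has keys 'weight', 'bias'

# strict=False suppresses the RuntimeError, but every mismatched key is
# silently skipped -- here nothing is loaded at all.
result = model.load_state_dict(plain_sd, strict=False)
print(result.missing_keys)     # ['module.weight', 'module.bias']
print(result.unexpected_keys)  # ['weight', 'bias']
```

If missing_keys covers the whole model, as here, the network still holds its randomly initialized weights, so test results would be garbage even though no error was raised.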

Although this suppresses the error, another problem can appear: the network rebuilt for testing must match the size and structure of the model saved during training, otherwise strict=False silently skips the mismatched parameters.
The model we built during training is as follows.

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    if torch.cuda.is_available():
        # model = nn.DataParallel(model)
        model.cuda()
        print('Training on GPU')
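The effect of that conditional wrap is easy to see on a toy module (nn.Linear here is just a stand-in): nn.DataParallel nests the original model under a `module` attribute, which prefixes every parameter key.

```python
import torch.nn as nn

net = nn.Linear(4, 2)
print(list(net.state_dict().keys()))
# ['weight', 'bias']

wrapped = nn.DataParallel(net)
print(list(wrapped.state_dict().keys()))
# ['module.weight', 'module.bias']
```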

So multi-GPU training wrapped the model with nn.DataParallel(model), and at test time the same wrapping must be applied before loading the parameters, otherwise the load fails. For single-GPU training the wrapper is unnecessary: just build the model and load the parameters directly.
If nn.DataParallel(model) was not executed, the saved parameters look like this:

print(model_param.keys())
print(model_param['input_conv.weight'].size())
#odict_keys(['input_conv.weight', 'input_conv.bias', 'conv_seq.0.conv1.weight', 'conv_seq.0.conv1.bias', 'conv_seq.0.conv2.weight', 'conv_seq.0.conv2.bias', 'conv_seq.1.conv1.weight', 'conv_seq.1.conv1.bias', 'conv_seq.1.conv2.weight', 'conv_seq.1.conv2.bias', 'conv_seq.2.conv1.weight', 'conv_seq.2.conv1.bias', 'conv_seq.2.conv2.weight', 'conv_seq.2.conv2.bias', 'conv_seq.3.conv1.weight', 'conv_seq.3.conv1.bias', 'conv_seq.3.conv2.weight', 'conv_seq.3.conv2.bias', 'conv_seq.4.conv1.weight', 'conv_seq.4.conv1.bias', 'conv_seq.4.conv2.weight', 'conv_seq.4.conv2.bias', 'conv_seq.5.conv1.weight', 'conv_seq.5.conv1.bias', 'conv_seq.5.conv2.weight', 'conv_seq.5.conv2.bias', 'conv_seq.6.conv1.weight', 'conv_seq.6.conv1.bias', 'conv_seq.6.conv2.weight', 'conv_seq.6.conv2.bias', 'conv_seq.7.conv1.weight', 'conv_seq.7.conv1.bias', 'conv_seq.7.conv2.weight', 'conv_seq.7.conv2.bias', 'conv_seq.8.conv1.weight', 'conv_seq.8.conv1.bias', 'conv_seq.8.conv2.weight', 'conv_seq.8.conv2.bias', 'conv_seq.9.conv1.weight', 'conv_seq.9.conv1.bias', 'conv_seq.9.conv2.weight', 'conv_seq.9.conv2.bias', 'conv_seq.10.conv1.weight', 'conv_seq.10.conv1.bias', 'conv_seq.10.conv2.weight', 'conv_seq.10.conv2.bias', 'conv_seq.11.conv1.weight', 'conv_seq.11.conv1.bias', 'conv_seq.11.conv2.weight', 'conv_seq.11.conv2.bias', 'conv_seq.12.conv1.weight', 'conv_seq.12.conv1.bias', 'conv_seq.12.conv2.weight', 'conv_seq.12.conv2.bias', 'conv_seq.13.conv1.weight', 'conv_seq.13.conv1.bias', 'conv_seq.13.conv2.weight', 'conv_seq.13.conv2.bias', 'conv_seq.14.conv1.weight', 'conv_seq.14.conv1.bias', 'conv_seq.14.conv2.weight', 'conv_seq.14.conv2.bias', 'conv_seq.15.conv1.weight', 'conv_seq.15.conv1.bias', 'conv_seq.15.conv2.weight', 'conv_seq.15.conv2.bias', 'conv.weight', 'conv.bias', 'output_conv.weight', 'output_conv.bias'])
#torch.Size([64, 3, 3, 3])

If nn.DataParallel(model) was executed, the output is:

print(model_param.keys())
print(model_param['module.input_conv.weight'].size())

#odict_keys(['module.input_conv.weight', 'module.input_conv.bias', 'module.conv_seq.0.conv1.weight', 'module.conv_seq.0.conv1.bias', 'module.conv_seq.0.conv2.weight', 'module.conv_seq.0.conv2.bias', 'module.conv_seq.1.conv1.weight', 'module.conv_seq.1.conv1.bias', 'module.conv_seq.1.conv2.weight', 'module.conv_seq.1.conv2.bias', 'module.conv_seq.2.conv1.weight', 'module.conv_seq.2.conv1.bias', 'module.conv_seq.2.conv2.weight', 'module.conv_seq.2.conv2.bias', 'module.conv_seq.3.conv1.weight', 'module.conv_seq.3.conv1.bias', 'module.conv_seq.3.conv2.weight', 'module.conv_seq.3.conv2.bias', 'module.conv_seq.4.conv1.weight', 'module.conv_seq.4.conv1.bias', 'module.conv_seq.4.conv2.weight', 'module.conv_seq.4.conv2.bias', 'module.conv_seq.5.conv1.weight', 'module.conv_seq.5.conv1.bias', 'module.conv_seq.5.conv2.weight', 'module.conv_seq.5.conv2.bias', 'module.conv_seq.6.conv1.weight', 'module.conv_seq.6.conv1.bias', 'module.conv_seq.6.conv2.weight', 'module.conv_seq.6.conv2.bias', 'module.conv_seq.7.conv1.weight', 'module.conv_seq.7.conv1.bias', 'module.conv_seq.7.conv2.weight', 'module.conv_seq.7.conv2.bias', 'module.conv_seq.8.conv1.weight', 'module.conv_seq.8.conv1.bias', 'module.conv_seq.8.conv2.weight', 'module.conv_seq.8.conv2.bias', 'module.conv_seq.9.conv1.weight', 'module.conv_seq.9.conv1.bias', 'module.conv_seq.9.conv2.weight', 'module.conv_seq.9.conv2.bias', 'module.conv_seq.10.conv1.weight', 'module.conv_seq.10.conv1.bias', 'module.conv_seq.10.conv2.weight', 'module.conv_seq.10.conv2.bias', 'module.conv_seq.11.conv1.weight', 'module.conv_seq.11.conv1.bias', 'module.conv_seq.11.conv2.weight', 'module.conv_seq.11.conv2.bias', 'module.conv_seq.12.conv1.weight', 'module.conv_seq.12.conv1.bias', 'module.conv_seq.12.conv2.weight', 'module.conv_seq.12.conv2.bias', 'module.conv_seq.13.conv1.weight', 'module.conv_seq.13.conv1.bias', 'module.conv_seq.13.conv2.weight', 'module.conv_seq.13.conv2.bias', 'module.conv_seq.14.conv1.weight', 'module.conv_seq.14.conv1.bias', 
'module.conv_seq.14.conv2.weight', 'module.conv_seq.14.conv2.bias', 'module.conv_seq.15.conv1.weight', 'module.conv_seq.15.conv1.bias', 'module.conv_seq.15.conv2.weight', 'module.conv_seq.15.conv2.bias', 'module.conv.weight', 'module.conv.bias', 'module.output_conv.weight', 'module.output_conv.bias'])
#torch.Size([64, 3, 3, 3])
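Instead of relying on strict=False, a cleaner fix when the architectures really do match is to rename the keys before loading. The helper below (a hypothetical name; it is plain dict manipulation, no GPU required) strips the 'module.' prefix so a checkpoint saved from a DataParallel model can be loaded into a plain one; adding the prefix back handles the opposite direction.

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to every key,
    so the weights can be loaded into a plain (non-DataParallel) model."""
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items()
    )

# Dummy values stand in for the real weight tensors.
sd = OrderedDict([('module.input_conv.weight', 1), ('module.input_conv.bias', 2)])
print(list(strip_module_prefix(sd).keys()))
# ['input_conv.weight', 'input_conv.bias']
```

After remapping, model.load_state_dict(strip_module_prefix(model_param)) can be called with the default strict=True, so any genuine architecture mismatch still fails loudly instead of being ignored.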
