RuntimeError: Error(s) in loading state_dict for XXX

Save the trained model with torch:

torch.save(model.state_dict(), 'file.pkl')

Load it back for testing:

model.load_state_dict(torch.load( 'file.pkl'))
model = model.cuda()

This fails with:

RuntimeError: Error(s) in loading state_dict for XXX:
	Missing key(s) in state_dict: "conv1.0.conv.weight", "conv1.1.weight", "conv1.1.bias", "conv1.1.running_mean", "conv1.1.running_var", "Block1.0.conv.weight", "Block1.1.weight", "Block1.1.bias", "Block1.1.running_mean", "Block1.1.running_var", "Block2.0.conv.weight", "Block2.1.weight", "Block2.1.bias", "Block2.1.running_mean", "Block2.1.running_var", "Block3.0.conv.weight", "Block3.1.weight", "Block3.1.bias", "Block3.1.running_mean", "Block3.1.running_var", "lastconv1.0.conv.weight", "lastconv1.1.weight", "lastconv1.1.bias", "lastconv1.1.running_mean", "lastconv1.1.running_var", "lastconv1.3.conv.weight", "sa1.conv1.weight", "sa2.conv1.weight", "sa3.conv1.weight". 
	Unexpected key(s) in state_dict: "module.conv1.0.conv.weight", "module.conv1.1.weight", "module.conv1.1.bias", "module.conv1.1.running_mean", "module.conv1.1.running_var", "module.conv1.1.num_batches_tracked", "module.Block1.0.conv.weight", "module.Block1.1.weight", "module.Block1.1.bias", "module.Block1.1.running_mean", "module.Block1.1.running_var", "module.Block1.1.num_batches_tracked", "module.Block2.0.conv.weight", "module.Block2.1.weight", "module.Block2.1.bias", "module.Block2.1.running_mean", "module.Block2.1.running_var", "module.Block2.1.num_batches_tracked", "module.Block3.0.conv.weight", "module.Block3.1.weight", "module.Block3.1.bias", "module.Block3.1.running_mean", "module.Block3.1.running_var", "module.Block3.1.num_batches_tracked", "module.lastconv1.0.conv.weight", "module.lastconv1.1.weight", "module.lastconv1.1.bias", "module.lastconv1.1.running_mean", "module.lastconv1.1.running_var", "module.lastconv1.1.num_batches_tracked", "module.lastconv1.3.conv.weight", "module.sa1.conv1.weight", "module.sa2.conv1.weight", "module.sa3.conv1.weight". 
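Reading two long key lists by eye is tedious. A small helper can compare the keys the model expects with the keys stored in the checkpoint and check whether the checkpoint keys are simply the model's keys with an extra prefix such as "module.". This is a hypothetical sketch (explain_key_mismatch is a made-up name, not a PyTorch API); it works on plain key collections, so in practice you would pass it model.state_dict().keys() and torch.load('file.pkl').keys().

```python
def explain_key_mismatch(model_keys, checkpoint_keys, prefix="module."):
    """Compare the keys a model expects with the keys found in a checkpoint.

    Returns (missing, unexpected, prefix_only):
      missing     -- keys the model expects but the checkpoint lacks
      unexpected  -- keys in the checkpoint the model does not know
      prefix_only -- True when stripping `prefix` from the checkpoint keys
                     covers every key the model expects
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    # Remove the prefix from every checkpoint key that carries it.
    stripped = {k[len(prefix):] for k in checkpoint_keys if k.startswith(prefix)}
    prefix_only = bool(missing) and model_keys <= stripped
    return missing, unexpected, prefix_only


# Example with a few keys taken from the error message above.
model_keys = ["conv1.0.conv.weight", "conv1.1.weight", "conv1.1.bias"]
ckpt_keys = ["module." + k for k in model_keys] + ["module.conv1.1.num_batches_tracked"]
missing, unexpected, prefix_only = explain_key_mismatch(model_keys, ckpt_keys)
print(prefix_only)  # True: the checkpoint keys differ only by the "module." prefix
```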

Many posts online blame a version mismatch and suggest changing load_state_dict(state_dict) to model.load_state_dict(state_dict, False), where the second positional argument is strict, i.e.

model.load_state_dict(torch.load( 'file.pkl'), False)

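This does make the error go away, but only by bluntly ignoring every key that doesn't match. Since none of the checkpoint's keys match the model's, no trained weights are actually loaded: the model keeps its freshly initialized parameters, so it is unusable and its predictions are wildly off.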
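If you do load with strict=False, it helps not to throw away its return value: in recent PyTorch versions load_state_dict returns an object with missing_keys and unexpected_keys lists. A minimal sketch of a check, where check_load_result is a hypothetical helper and the FakeResult namedtuple is only a stand-in so the example runs without torch:

```python
from collections import namedtuple


def check_load_result(result):
    """Raise if a non-strict load_state_dict() silently skipped keys.

    `result` is expected to expose `missing_keys` and `unexpected_keys`,
    as the value returned by model.load_state_dict(..., strict=False) does.
    """
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"{len(missing)} missing and {len(unexpected)} unexpected keys; "
            f"e.g. missing={missing[:3]}, unexpected={unexpected[:3]}"
        )
    return result


# Stand-in for PyTorch's return value, used only for this demo.
FakeResult = namedtuple("FakeResult", ["missing_keys", "unexpected_keys"])

# Real usage (assumption, requires torch and a model):
#   check_load_result(model.load_state_dict(torch.load('file.pkl'), strict=False))
clean = check_load_result(FakeResult([], []))
```

With the checkpoint from this post, every key would show up as missing or unexpected, so the check raises instead of letting an untrained model through.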

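After puzzling over it for a long time, I found the cause: training ran on multiple GPUs in parallel, with the model wrapped in torch.nn.DataParallel: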

model = torch.nn.DataParallel(model, device_ids=[0,1,2]).cuda()

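but when reloading it for testing, I did not wrap it in torch.nn.DataParallel. DataParallel keeps the original network in its "module" attribute, so every key in the saved state_dict starts with "module.", which is exactly the difference between the missing and unexpected keys above. The fix is simple: add one line, model = nn.DataParallel(model), to the original code before loading, i.e.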

import torch
import torch.nn as nn
model = nn.DataParallel(model)  # wrap first so the keys get the same "module." prefix
model.load_state_dict(torch.load('file.pkl'))
model = model.cuda()

Problem solved.
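An alternative fix, not the one used above, is to remove the "module." prefix from the checkpoint keys so the weights load into an unwrapped model (or, at training time, save model.module.state_dict() instead of model.state_dict()). A minimal sketch, with strip_prefix as a hypothetical helper name; it only renames keys and leaves the tensor values untouched:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key that has it.

    Keys without the prefix are kept unchanged; values are not copied or modified.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Real usage (assumption, requires torch and an unwrapped model):
#   model.load_state_dict(strip_prefix(torch.load('file.pkl')))
demo = strip_prefix(OrderedDict([("module.conv1.1.weight", 1), ("module.conv1.1.bias", 2)]))
print(list(demo))  # ['conv1.1.weight', 'conv1.1.bias']
```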
