es/module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.features.0.0.weight", "module.features.0.1.weight", "module.features.0.1.bias", "module.features.0.1.running_mean", "module.features.0.1.running_var", "module.features.1.conv.0.weight", "module.features.1.conv.1.weight", "module.features.1.conv.1.bias", "module.features.1.conv.1.running_mean", "module.features.1.conv.1.running_var", "module.features.1.conv.3.weight", "module.features.1.conv.4.weight", "module.features.1.conv.4.bias", "module.features.1.conv.4.running_mean", "module.features.1.conv.4.running_var", "module.features.2.conv.0.weight", "module.features.2.conv.1.weight", "module.features.2.conv.1.bias", "module.features.2.conv.1.running_mean", "module.features.2.conv.1.running_var", "module.features.2.conv.3.weight", "module.features.2.conv.4.weight", "module.features.2.conv.4.bias", "module.features.2.conv.4.running_mean", "module.features.2.conv.4.running_var", "module.features.2.conv.6.weight", "module.features.2.conv.7.weight", "module.features.2.conv.7.bias", "module.features.2.conv.7.running_mean", "module.features.2.conv.7.running_var", "module.features.3.conv.0.weight", "module.features.3.conv.1.weight", "module.features.3.conv.1.bias", "module.features.3.conv.1.running_mean", "module.features.3.conv.1.running_var", "module.features.3.conv.3.weight", "module.features.3.conv.4.weight", "module.features.3.conv.4.bias", "module.features.3.conv.4.running_mean", "module.features.3.conv.4.running_var", "module.features.3.conv.6.weight", "module.features.3.conv.7.weight", "module.features.3.conv.7.bias", "module.features.3.conv.7.running_mean", "module.features.3.conv.7.running_var", "module.features.4.conv.0.weight", "module.features.4.conv.1.weight", "module.features.4.conv.1.bias", "module.features.4.conv.1.running_mean", "module.features.4.conv.1.running_var", "module.features.4.conv.3.weight", "module.features.4.conv.4.weight", "module.features.4.conv.4.bias", "module.features.4.conv.4.running_mean", 
"module.features.4.conv.4.running_var", "module.features.4.conv.6.weight", "module.features.4.conv.7.weight", "module.features.4.conv.7.bias", "module.features.4.conv.7.running_mean", "module.features.4.conv.7.running_var", "module.features.5.conv.0.weight", "module.features.5.conv.1.weight", "module.features.5.conv.1.bias", "module.features.5.conv.1.running_mean", "module.features.5.conv.1.running_var", "module.features.5.conv.3.weight", "module.features.5.conv.4.weight", "module.features.5.conv.4.bias", "module.features.5.conv.4.running_mean", "module.features.5.conv.4.running_var", "module.features.5.conv.6.weight", "module.features.5.conv.7.weight", "module.features.5.conv.7.bias", "module.features.5.conv.7.running_mean", "module.features.5.conv.7.running_var", "module.features.6.conv.0.weight", "module.features.6.conv.1.weight", "module.features.6.conv.1.bias", "module.features.6.conv.1.running_mean", "module.features.6.conv.1.running_var", "module.features.6.conv.3.weight", "module.features.6.conv.4.weight", "module.features.6.conv.4.bias", "module.features.6.conv.4.running_mean", "module.features.6.conv.4.running_var", "module.features.6.conv.6.weight", "module.features.6.conv.7.weight", "module.features.6.conv.7.bias", "module.features.6.conv.7.running_mean", "module.features.6.conv.7.running_var", "module.features.7.conv.0.weight", "module.features.7.conv.1.weight", "module.features.7.conv.1.bias", "module.features.7.conv.1.running_mean", "module.features.7.conv.1.running_var", "module.features.7.conv.3.weight", "module.features.7.conv.4.weight", "module.features.7.conv.4.bias", "module.features.7.conv.4.running_mean", "module.features.7.conv.4.running_var", "module.features.7.conv.6.weight", "module.features.7.conv.7.weight", "module.features.7.conv.7.bias", "module.features.7.conv.7.running_mean", "module.features.7.conv.7.running_var", "module.features.8.conv.0.weight", "module.features.8.conv.1.weight", "module.features.8.conv.1.bias", 
"module.features.8.conv.1.running_mean", "module.features.8.conv.1.running_var", "module.features.8.conv.3.weight", "module.features.8.conv.4.weight", "module.features.8.conv.4.bias", "module.features.8.conv.4.running_mean", "module.features.8.conv.4.running_var", "module.features.8.conv.6.weight", "module.features.8.conv.7.weight", "module.features.8.conv.7.bias", "module.features.8.conv.7.running_mean", "module.features.8.conv.7.running_var", "module.features.9.conv.0.weight", "module.features.9.conv.1.weight", "module.features.9.conv.1.bias", "module.features.9.conv.1.running_mean", "module.features.9.conv.1.running_var", "module.features.9.conv.3.weight", "module.features.9.conv.4.weight", "module.features.9.conv.4.bias", "module.features.9.conv.4.running_mean", "module.features.9.conv.4.running_var", "module.features.9.conv.6.weight", "module.features.9.conv.7.weight", "module.features.9.conv.7.bias", "module.features.9.conv.7.running_mean", "module.features.9.conv.7.running_var", "module.features.10.conv.0.weight", "module.features.10.conv.1.weight", "module.features.10.conv.1.bias", "module.features.10.conv.1.running_mean", "module.features.10.conv.1.running_var", "module.features.10.conv.3.weight", "module.features.10.conv.4.weight", "module.features.10.conv.4.bias", "module.features.10.conv.4.running_mean", "module.features.10.conv.4.running_var", "module.features.10.conv.6.weight", "module.features.10.conv.7.weight", "module.features.10.conv.7.bias", "module.features.10.conv.7.running_mean", "module.features.10.conv.7.running_var", "module.features.11.conv.0.weight", "module.features.11.conv.1.weight", "module.features.11.conv.1.bias", "module.features.11.conv.1.running_mean", "module.features.11.conv.1.running_var", "module.features.11.conv.3.weight", "module.features.11.conv.4.weight", "module.features.11.conv.4.bias", "module.features.11.conv.4.running_mean", "module.features.11.conv.4.running_var", "module.features.11.conv.6.weight", 
"module.features.11.conv.7.weight", "module.features.11.conv.7.bias", "module.features.11.conv.7.running_mean", "module.features.11.conv.7.running_var", "module.features.12.conv.0.weight", "module.features.12.conv.1.weight", "module.features.12.conv.1.bias", "module.features.12.conv.1.running_mean", "module.features.12.conv.1.running_var", "module.features.12.conv.3.weight", "module.features.12.conv.4.weight", "module.features.12.conv.4.bias", "module.features.12.conv.4.running_mean", "module.features.12.conv.4.running_var", "module.features.12.conv.6.weight", "module.features.12.conv.7.weight", "module.features.12.conv.7.bias", "module.features.12.conv.7.running_mean", "module.features.12.conv.7.running_var", "module.features.13.conv.0.weight", "module.features.13.conv.1.weight", "module.features.13.conv.1.bias", "module.features.13.conv.1.running_mean", "module.features.13.conv.1.running_var", "module.features.13.conv.3.weight", "module.features.13.conv.4.weight", "module.features.13.conv.4.bias", "module.features.13.conv.4.running_mean", "module.features.13.conv.4.running_var", "module.features.13.conv.6.weight", "module.features.13.conv.7.weight", "module.features.13.conv.7.bias", "module.features.13.conv.7.running_mean", "module.features.13.conv.7.running_var", "module.features.14.conv.0.weight", "module.features.14.conv.1.weight", "module.features.14.conv.1.bias", "module.features.14.conv.1.running_mean", "module.features.14.conv.1.running_var", "module.features.14.conv.3.weight", "module.features.14.conv.4.weight", "module.features.14.conv.4.bias", "module.features.14.conv.4.running_mean", "module.features.14.conv.4.running_var", "module.features.14.conv.6.weight", "module.features.14.conv.7.weight", "module.features.14.conv.7.bias", "module.features.14.conv.7.running_mean", "module.features.14.conv.7.running_var", "module.features.15.conv.0.weight", "module.features.15.conv.1.weight", "module.features.15.conv.1.bias", 
"module.features.15.conv.1.running_mean", "module.features.15.conv.1.running_var", "module.features.15.conv.3.weight", "module.features.15.conv.4.weight", "module.features.15.conv.4.bias", "module.features.15.conv.4.running_mean", "module.features.15.conv.4.running_var", "module.features.15.conv.6.weight", "module.features.15.conv.7.weight", "module.features.15.conv.7.bias", "module.features.15.conv.7.running_mean", "module.features.15.conv.7.running_var", "module.features.16.conv.0.weight", "module.features.16.conv.1.weight", "module.features.16.conv.1.bias", "module.features.16.conv.1.running_mean", "module.features.16.conv.1.running_var", "module.features.16.conv.3.weight", "module.features.16.conv.4.weight", "module.features.16.conv.4.bias", "module.features.16.conv.4.running_mean", "module.features.16.conv.4.running_var", "module.features.16.conv.6.weight", "module.features.16.conv.7.weight", "module.features.16.conv.7.bias", "module.features.16.conv.7.running_mean", "module.features.16.conv.7.running_var", "module.features.17.conv.0.weight", "module.features.17.conv.1.weight", "module.features.17.conv.1.bias", "module.features.17.conv.1.running_mean", "module.features.17.conv.1.running_var", "module.features.17.conv.3.weight", "module.features.17.conv.4.weight", "module.features.17.conv.4.bias", "module.features.17.conv.4.running_mean", "module.features.17.conv.4.running_var", "module.features.17.conv.6.weight", "module.features.17.conv.7.weight", "module.features.17.conv.7.bias", "module.features.17.conv.7.running_mean", "module.features.17.conv.7.running_var", "module.features.18.0.weight", "module.features.18.1.weight", "module.features.18.1.bias", "module.features.18.1.running_mean", "module.features.18.1.running_var", "module.classifier.1.weight", "module.classifier.1.bias".
Unexpected key(s) in state_dict: "epoch", "arch", "state_dict", "best_prec1", "optimizer".
I recently needed the PyTorch version of MobileNetV2. After training it on my data, loading the trained model (torch.load followed by load_state_dict) kept failing with:
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
or Unexpected key(s) in state_dict: "epoch", "arch", "state_dict", "best_prec1", "optimizer".
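Before trying fixes, it can help to work out which kind of mismatch you have. The sketch below compares the checkpoint's top-level keys with the model's parameter names. `diagnose_mismatch` is a hypothetical helper, not a PyTorch API; it assumes the checkpoint keys come from `torch.load(path).keys()` and the model keys from `model.state_dict().keys()`.

```python
from typing import Iterable, List


def diagnose_mismatch(checkpoint_keys: Iterable[str], model_keys: Iterable[str]) -> List[str]:
    """Return hints about why model.load_state_dict(checkpoint) would fail."""
    ckpt = set(checkpoint_keys)
    model_set = set(model_keys)

    # A training checkpoint wraps the weights in a dict with extra entries
    # (epoch, optimizer, ...); its top-level keys are not parameter names.
    if "state_dict" in ckpt:
        return ["checkpoint is a training dict: pass checkpoint['state_dict'] instead"]

    hints = []
    # torch.nn.DataParallel registers the wrapped model as `.module`,
    # so every parameter name of a wrapped model starts with "module.".
    ckpt_prefixed = any(k.startswith("module.") for k in ckpt)
    model_prefixed = any(k.startswith("module.") for k in model_set)
    if ckpt_prefixed and not model_prefixed:
        hints.append("checkpoint keys start with 'module.' but the model's do not")
    elif model_prefixed and not ckpt_prefixed:
        hints.append("model keys start with 'module.' but the checkpoint's do not")
    elif ckpt != model_set:
        hints.append("keys differ for another reason, e.g. a different architecture")
    return hints
```

For the error above, the checkpoint's top-level keys are epoch, arch, state_dict, best_prec1 and optimizer, so the helper returns only the first hint.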
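Below are three solutions I found online. I tried all of them and the error persisted; I also had the impression that they might slow the model down.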
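Solution 1: this error can appear when the weights were saved from a model wrapped in torch.nn.DataParallel() (whose parameter names all start with "module."), while the model you are loading them into is not wrapped. The suggested fix is to wrap the model before loading the weights:
import torch
import torch.backends.cudnn as cudnn
model = torch.nn.DataParallel(model)  # parameter names now start with "module."
cudnn.benchmark = True  # optional cuDNN autotuning; unrelated to the key mismatch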
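The first solution changes the model so it matches the checkpoint. The reverse situation produces the same kind of "Missing key(s) ... module.*" error: the model is wrapped in DataParallel, but the weights were saved from an unwrapped model. A minimal sketch of renaming the checkpoint keys instead, assuming the loaded object is a flat mapping of parameter names to tensors; `add_prefix` is a hypothetical helper, not part of PyTorch:

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def add_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a copy of state_dict whose keys all start with `prefix`.

    Keys that already carry the prefix are left unchanged, so calling this
    twice is harmless. The original mapping is not modified.
    """
    return OrderedDict(
        (k if k.startswith(prefix) else prefix + k, v) for k, v in state_dict.items()
    )
```

With a wrapped model this would be used as `model.load_state_dict(add_prefix(torch.load('myfile.pth')))`. Alternatively, the unwrapped model is reachable as `model.module`, so `model.module.load_state_dict(...)` avoids renaming altogether.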
Solution 2:
import torch
from collections import OrderedDict

# original saved file with DataParallel
state_dict = torch.load('myfile.pth')
# create new OrderedDict that does not contain `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
Solution 3:
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('myfile.pth').items()})
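The second and third solutions both assume that every key starts with "module.": `k[7:]` would also cut seven characters off a key that lacks the prefix, and `str.replace` removes "module." wherever it occurs, even in the middle of a name (for example a submodule that happens to be called `module`). A slightly more defensive sketch, with `strip_prefix` as a hypothetical helper:

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged (unlike k[7:]),
    and a "module." in the middle of a key is left alone (unlike str.replace).
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
```

Usage would be `model.load_state_dict(strip_prefix(torch.load('myfile.pth')))`. On Python 3.9+, `k.removeprefix(prefix)` does the same per-key work.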
The solutions above come from https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3 and https://blog.csdn.net/kaixinjiuxing666/article/details/85115077.
None of them solved my problem, though. Here is how I finally fixed it.
Solution 4: my solution
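import torch
model = MobileNetV2()  # the same model class that was trained
checkpoint = torch.load(modelpath)  # modelpath is the path to the trained model file you want to load
model.load_state_dict(checkpoint['state_dict'])  # load only the weights
output = model(x)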
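Why this works: the checkpoint file saved during training is a dictionary with several entries (epoch, arch, state_dict, best_prec1, optimizer), but we only need the model parameters, which are stored under 'state_dict'. Passing the whole dictionary to load_state_dict fails because its top-level keys are not parameter names: PyTorch reports them as "Unexpected key(s)" and reports every real parameter as "Missing key(s)".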
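The fix generalizes to a small helper that accepts either kind of file. This is a sketch assuming the checkpoint layout reported in the error (weights under 'state_dict' next to epoch, arch, best_prec1 and optimizer); `extract_state_dict` is a hypothetical name, and the 'state_dict' key is a convention of the training script, not something PyTorch enforces.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, key="state_dict"):
    """Return the model weights from whatever torch.load returned.

    A training checkpoint is a dict holding the weights under `key` plus
    bookkeeping entries (epoch, optimizer, ...); return only the weights.
    Anything else is assumed to already be a plain state_dict.
    """
    if isinstance(checkpoint, Mapping) and isinstance(checkpoint.get(key), Mapping):
        return checkpoint[key]
    return checkpoint
```

Typical use: `model.load_state_dict(extract_state_dict(torch.load(modelpath)))`. If the weights were saved from a DataParallel-wrapped model, the inner keys still start with "module." and need stripping as in the second and third solutions. For inference, call `model.eval()` first so the BatchNorm layers use their stored running statistics, and `torch.load(modelpath, map_location='cpu')` lets a GPU-trained checkpoint load on a CPU-only machine.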
Problem solved!