Error(s) pytorch 加载checkpoint state_dict出错:Missing key(s) && Unexpected key(s) in state_dict

ERROR: 

Traceback (most recent call last):
File "test_0.py", line 130, in 
model = load_model()
File "test_0.py", line 104, in load_model
model.load_state_dict(checkpoint['state_dict'])
File "/home/cosmo/anaconda3/envs/tf8/lib/python3.6/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
self.class.name, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.3.weight", "module.features.3.bias", "module.features.6.weight", "module.features.6.bias", "module.features.8.weight", "module.features.8.bias", "module.features.10.weight", "module.features.10.bias", "module.classifier.weight", "module.classifier.bias".
Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.3.weight", "features.module.3.bias", "features.module.6.weight", "features.module.6.bias", "features.module.8.weight", "features.module.8.bias", "features.module.10.weight", "features.module.10.bias", "classifier.weight", "classifier.bias".

REASON:

Error(s) pytorch 加载checkpoint state_dict出错:Missing key(s) && Unexpected key(s) in state_dict_第1张图片

 The problem is the module is load with dataparallel activated and you are trying to load it without data parallel. That's why there's an extra module at the beginning of each key!

错误原因就是net.load_state_dict的时候,net的状态不是处在gpu并行状态,而存储的net模型checkpoint是在gpu并行状态下的!

SOLVE:

在net.load_state_dict前将net的状态设置成gpu并行模式就好了。

Error(s) pytorch 加载checkpoint state_dict出错:Missing key(s) && Unexpected key(s) in state_dict_第2张图片

你可能感兴趣的:(Pytorch)