加载模型出现-RuntimeError: Error(s) in loading state_dict for Net:unexpected key(s) in state_dict: XXX

加载模型时发生错误RuntimeError: Error(s) in loading state_dict for Net:unexpected key(s) in state_dict: XXX
加载模型出现-RuntimeError: Error(s) in loading state_dict for Net:unexpected key(s) in state_dict: XXX_第1张图片

Traceback (most recent call last):
  File "demo.py", line 380, in <module>
    model.load_state_dict(torch.load('./0428.pth'))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViT:
        Unexpected key(s) in state_dict: "transformer.skipcat.3.weight", "transformer.skipcat.3.bias", "transformer.skipcat.4.weight", "transformer.skipcat.4.bias". 

原因:

加载使用模型时和训练模型时的环境不一致.

解决方法:

将load_state_dict(state_dict) 改成 model.load_state_dict(state_dict, False)

model.load_state_dict(torch.load('models/params.pt'),strict=False)

问题解决~

你可能感兴趣的:(深度学习报错调试合集,pytorch,深度学习)