单卡加载多卡训练保存的模型

1、问题
直接用加载单卡模型的代码来加载多卡训练保存的模型时会报这样一个错误:
RuntimeError: Error(s) in loading state_dict for : Missing key(s) in state_dict

2、原因
原因很简单,就是:模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.

3、解决

path = "model_0.pt"
new_state_dict={}
for k, v in torch.load(str(path)).items():
    new_state_dict[k[7:]] = v			#键值包含‘module.’ 则删除 
self.load_state_dict(new_state_dict,strict=False)

另外可参考博客:
https://blog.csdn.net/lei_qi/article/details/118607234?spm=1001.2014.3001.5502

https://zhuanlan.zhihu.com/p/371090724

你可能感兴趣的:(python,开发语言)