pytorch加载模型报错RuntimeError:Error(s) in loading state_dict for DataParallel

完整报错信息:
RuntimeError:Error(s) in loading state_dict for DataParallel:
Unexpected key(s) in state_dict:“module.resnet.bn1.num_batches_tracked”,"module.resnet.layer1.0.bn1.num_batches_tracked"等等,遇到这种错误,说明你训练模型和测试加载模型所使用的环境不一致,解决方法:
1>将环境改为一致
2>我当时的环境是:训练pytorch1.0,测试环境Pytorch0.4,只需要把加载模型那一块做一个简单的修改,如以下代码:

checkpoint_file = os.path.join(args.checkpoint, args.test+'.pth.tar')
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['state_dict'])

上面是我出错的代码,解决方法是在最后一行的括号里面加上False,如下:

checkpoint_file = os.path.join(args.checkpoint, args.test+'.pth.tar')
checkpoint = torch.load(checkpoint_file) 
model.load_state_dict(checkpoint['state_dict'],False) # 修改处

你可能感兴趣的:(深度学习)