When loading a previously trained model, the following error was raised:
```
Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.conv0.weight", "module.bn0.weight", "module.bn0.bias", "module.bn0.running_mean", "module.bn0.running_var", "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.conv2.weight", "module.bn2.weight", "module.bn2.bias", "module.bn2.running_mean", "module.bn2.running_var", "module.conv3.weight", "module.bn3.weight", "module.bn3.bias", "module.bn3.running_mean", "module.bn3.running_var", "module.conv4.weight", "module.bn4.weight", "module.bn4.bias", "module.bn4.running_mean", "module.bn4.running_var", "module.conv5.weight", "module.bn5.weight", "module.bn5.bias", "module.bn5.running_mean", "module.bn5.running_var", "module.fc.weight", "module.fc.bias".
Unexpected key(s) in state_dict: "epoch", "state_dict", "best_prec1".
```
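To see where both key lists come from, here is a minimal reproduction. It is a sketch under assumptions: a hypothetical nn.Linear(4, 2) wrapped in nn.DataParallel stands in for the real network, the epoch/best_prec1 values are made up, and the checkpoint goes to an in-memory buffer instead of a file. The whole checkpoint dictionary is passed to load_state_dict, just as in the failing code.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the real network; its keys are "module.weight" and "module.bias".
model = nn.DataParallel(nn.Linear(4, 2))

# Save a checkpoint dict shaped like the training script's (values are made up).
buffer = io.BytesIO()
torch.save({"epoch": 1, "state_dict": model.state_dict(), "best_prec1": 0.0}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

message = ""
try:
    # Wrong: passes the wrapper dict, so its top-level keys are treated as parameter names.
    model.load_state_dict(checkpoint)
except RuntimeError as err:
    message = str(err)

print(message)  # lists "module.weight"/"module.bias" as missing, "epoch" etc. as unexpected
```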
Code used to save the model:
```python
save_checkpoint({
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
}, is_best, filename=os.path.join(args.save_dir, 'model.th'))
```
The save_checkpoint function:
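```python
import shutil

import torch


# Save the latest checkpoint and, if it is the best so far, a copy as the best model
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Save the training model
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
```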
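A side note on this function: it copies the best checkpoint to model_best.pth.tar in the current working directory, not into args.save_dir next to model.th. If you would rather keep both files together, one possible variant (a sketch, not the original author's code) derives the destination from filename:

```python
import os
import shutil

import torch


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    """Save the latest checkpoint; if best, copy it to model_best.pth.tar in the same folder."""
    torch.save(state, filename)
    if is_best:
        # Variant (assumption): place the best copy beside `filename` instead of in the CWD.
        # os.path.dirname("checkpoint.pth.tar") is "", so the default still writes to the CWD.
        best_path = os.path.join(os.path.dirname(filename), "model_best.pth.tar")
        shutil.copyfile(filename, best_path)
```

The saved file is still the same wrapper dictionary, so it must be loaded the same way as before.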
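As the code shows, the object passed to torch.save is not the model's state_dict itself but a dictionary that wraps it: { 'state_dict': model.state_dict(), 'best_prec1': best_prec1, } (the error message shows that the saved file also contains an 'epoch' key). Passing this whole dictionary to load_state_dict makes PyTorch treat its top-level keys as parameter names, so every real parameter is reported as missing and 'epoch', 'state_dict' and 'best_prec1' are reported as unexpected. When loading, you therefore have to select the 'state_dict' entry explicitly.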
Change the loading code to:
```python
model.load_state_dict(torch.load("C:\\Users\\83543\\Desktop\\model_best.pth.tar")['state_dict'])
```
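The one-liner above works when the checkpoint was saved from a DataParallel model, is loaded into a DataParallel model, and the machine has the device the tensors were saved on (torch.load restores tensors to their original device unless map_location is given). A slightly more defensive sketch follows; the helper name load_checkpoint_weights is hypothetical. It loads onto the CPU, unwraps the 'state_dict' entry if present (otherwise it assumes a bare state_dict), and adds or strips the "module." prefix that nn.DataParallel puts in front of every key.

```python
import torch
import torch.nn as nn


def load_checkpoint_weights(model, path):
    """Hypothetical helper: load weights from a checkpoint that may wrap them in {"state_dict": ...}."""
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(path, map_location="cpu")
    # Unwrap the training checkpoint if needed; otherwise assume it is already a state_dict.
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Reconcile the "module." prefix that nn.DataParallel adds to every parameter name.
    wants_prefix = isinstance(model, nn.DataParallel)
    has_prefix = all(k.startswith("module.") for k in state_dict)
    if has_prefix and not wants_prefix:
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    elif wants_prefix and not has_prefix:
        state_dict = {"module." + k: v for k, v in state_dict.items()}

    # Strict loading still raises if the architecture itself does not match.
    model.load_state_dict(state_dict)
    return checkpoint
```

Returning the full checkpoint keeps the other entries (such as best_prec1) available, e.g. `ckpt = load_checkpoint_weights(model, path); print(ckpt.get("best_prec1"))`.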
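Also make sure the network you build matches the checkpoint. In my case the file contained ResNet-20 parameters, but the model I had constructed was a VGG, which also causes loading to fail. Always load weights into the same architecture they were trained with.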
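Before calling load_state_dict, it can help to diff the checkpoint against the model you built, so that an architecture mismatch (such as ResNet-20 weights versus a VGG model) shows up as a readable report rather than a long error. The helper diff_state_dicts and the two small nn.Sequential networks below are hypothetical stand-ins, not the real ResNet-20 or VGG definitions.

```python
import torch.nn as nn


def diff_state_dicts(model, state_dict):
    """Return (missing, unexpected, shape-mismatched) keys between a model and a state_dict."""
    expected = model.state_dict()
    missing = sorted(set(expected) - set(state_dict))
    unexpected = sorted(set(state_dict) - set(expected))
    mismatched = sorted(
        k for k in set(expected) & set(state_dict)
        if expected[k].shape != state_dict[k].shape
    )
    return missing, unexpected, mismatched


# Hypothetical stand-ins: weights saved from one architecture, a model of another.
saved_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
built_net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 16, 3))

missing, unexpected, mismatched = diff_state_dicts(built_net, saved_net.state_dict())
print("missing:", missing)
print("unexpected:", unexpected)
print("shape mismatch:", mismatched)
```

If all three lists are empty (after unwrapping 'state_dict' and reconciling the "module." prefix), strict loading should succeed; otherwise the lists point at the layers that differ.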