Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: “module.conv0.weight

在加载已经训练好的模型时,报错。

报错描述:

Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: “module.conv0.weight”, “module.bn0.weight”, “module.bn0.bias”, “module.bn0.running_mean”, “module.bn0.running_var”, “module.conv1.weight”, “module.bn1.weight”, “module.bn1.bias”, “module.bn1.running_mean”, “module.bn1.running_var”, “module.conv2.weight”, “module.bn2.weight”, “module.bn2.bias”, “module.bn2.running_mean”, “module.bn2.running_var”, “module.conv3.weight”, “module.bn3.weight”, “module.bn3.bias”, “module.bn3.running_mean”, “module.bn3.running_var”, “module.conv4.weight”, “module.bn4.weight”, “module.bn4.bias”, “module.bn4.running_mean”, “module.bn4.running_var”, “module.conv5.weight”, “module.bn5.weight”, “module.bn5.bias”, “module.bn5.running_mean”, “module.bn5.running_var”, “module.fc.weight”, “module.fc.bias”.
Unexpected key(s) in state_dict: “epoch”, “state_dict”, “best_prec1”.

原因:

保存模型的代码:

save_checkpoint({
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
    }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

该函数:

# 保存最新和最佳模型
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Save the training model
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

由以上可知,在调用torch.save函数时,state对应的是{ 'state_dict': model.state_dict(), 'best_prec1': best_prec1, },因此在加载的时候要指明值。

解决方案:

修改为:

model.load_state_dict(torch.load("C:\\Users\\83543\\Desktop\\model_best.pth.tar")['state_dict'])

第二种错误原因:

导入的是ResNet-20的模型参数的模型文件,但是引入的模型框架是vgg的导致错误,记得对应。

你可能感兴趣的:(torch)