PyTorch加载模型碰到Missing key(s) in state_dict报错

问题描述

  • PyTorch训练好模型以后,需要加载模型,加载模型代码如下
ckpt = torch.load(model_path_len)
model.load_state_dict(ckpt['state_dict'])
  • 结果碰到的问题为:Missing key(s) in state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GCN:
	Missing key(s) in state_dict: "gcbs.12.gc1.weight", "gcbs.12.gc1.att", "gcbs.12.gc1.bias", "gcbs.12.bn1.weight", "gcbs.12.bn1.bias", "gcbs.12.bn1.running_mean", "gcbs.12.bn1.running_var", "gcbs.12.gc2.weight", "gcbs.12.gc2.att", "gcbs.12.gc2.bias", "gcbs.12.bn2.weight", "gcbs.12.bn2.bias", "gcbs.12.bn2.running_mean", 
  • 解决办法,亲测有效
ckpt = torch.load(model_path_len)
model.load_state_dict(ckpt['state_dict'],strict=False)

高级做法

  • 上面的做法会导致一些参数加载不进来
  • 高级做法是把原模型pth文件的key打印与现在模型的key进行比较,手动的为模型加载参数

你可能感兴趣的:(解决问题,机器学习专题,pytorch,深度学习,python)