RuntimeError: Error(s) in loading state_dict, Missing key(s) in state_dict, Unexpected key(s) in state_dict

  1. pytorch加载模型错误信息:
RuntimeError: Error(s) in loading state_dict for SimpleDLA:
	Missing key(s) in state_dict: "base.0.weight", "base.1.weight", "base.1.bias", "base.1.running_mean", "base.1.running_var", "layer1.0.weight", "layer1.1.weight", "layer1.1.bias", "layer1.1.running_mean", "layer1.1.running_var", "layer2.0.weight", "layer2.1.weight", "layer2.1.bias", "layer2.1.running_mean", "layer2.1.running_var", "layer3.root.conv.weight", "layer3.root.bn.weight", "layer3.root.bn.bias", "layer3.root.bn.running_mean", "layer3.root.bn.running_var", "layer3.left_tree.conv1.weight", "layer3.left_tree.bn1.weight", "layer3.left_tree.bn1.bias", "layer3.left_tree.bn1.running_mean", "layer3.left_tree.bn1.running_var", "layer3.left_tree.conv2.weight", "layer3.left_tree.bn2.weight", "layer3.left_tree.bn2.bias", "layer3.left_tree.bn2.running_mean", "layer3.left_tree.bn2.running_var", "layer3.left_tree.shortcut.0.weight", "layer3.left_tree.shortcut.1.weight", "layer3.left_tree.shortcut.1.bias", "layer3.left_tree.shortcut.1.running_mean", "layer3.left_tree.shortcut.1.running_var", "layer3.right_tree.conv1.weight", "layer3.right_tree.bn1.weight", "layer3.right_tree.bn1.bias", "layer3.right_tree.bn1.running_mean", "layer3.right_tree.bn1.running_var", "layer3.right_tree.conv2.weight", "layer3.right_tree.bn2.weight", "layer3.right_tree.bn2.bias", "layer3.right_tree.bn2.running_mean", "layer3.right_tree.bn2.running_var", "layer4.root.conv.weight", "layer4.root.bn.weight", "layer4.root.bn.bias", "layer4.root.bn.running_mean", "layer4.root.bn.running_var", "layer4.left_tree.root.conv.weight", "layer4.left_tree.root.bn.weight", "layer4.left_tree.root.bn.bias", "layer4.left_tree.root.bn.running_mean", "layer4.left_tree.root.bn.running_var", "layer4.left_tree.left_tree.conv1.weight", "layer4.left_tree.left_tree.bn1.weight", "layer4.left_tree.left_tree.bn1.bias", "layer4.left_tree.left_tree.bn1.running_mean", "layer4.left_tree.left_tree.bn1.running_var", "layer4.left_tree.left_tree.conv2.weight", "layer4.left_tree.left_tree.bn2.weight", 
"layer4.left_tree.left_tree.bn2.bias", "layer4.left_tree.left_tree.bn2.running_mean", "layer4.left_tree.left_tree.bn2.running_var", "layer4.left_tree.left_tree.shortcut.0.weight", "layer4.left_tree.left_tree.shortcut.1.weight", "layer4.left_tree.left_tree.shortcut.1.bias", "layer4.left_tree.left_tree.shortcut.1.running_mean", "layer4.left_tree.left_tree.shortcut.1.running_var", "layer4.left_tree.right_tree.conv1.weight", "layer4.left_tree.right_tree.bn1.weight", "layer4.left_tree.right_tree.bn1.bias", "layer4.left_tree.right_tree.bn1.running_mean", "layer4.left_tree.right_tree.bn1.running_var", "layer4.left_tree.right_tree.conv2.weight", "layer4.left_tree.right_tree.bn2.weight", "layer4.left_tree.right_tree.bn2.bias", "layer4.left_tree.right_tree.bn2.running_mean", "layer4.left_tree.right_tree.bn2.running_var", "layer4.right_tree.root.conv.weight", "layer4.right_tree.root.bn.weight", "layer4.right_tree.root.bn.bias", "layer4.right_tree.root.bn.running_mean", "layer4.right_tree.root.bn.running_var", "layer4.right_tree.left_tree.conv1.weight", "layer4.right_tree.left_tree.bn1.weight", "layer4.right_tree.left_tree.bn1.bias", "layer4.right_tree.left_tree.bn1.running_mean", "layer4.right_tree.left_tree.bn1.running_var", "layer4.right_tree.left_tree.conv2.weight", "layer4.right_tree.left_tree.bn2.weight", "layer4.right_tree.left_tree.bn2.bias", "layer4.right_tree.left_tree.bn2.running_mean", "layer4.right_tree.left_tree.bn2.running_var", "layer4.right_tree.right_tree.conv1.weight", "layer4.right_tree.right_tree.bn1.weight", "layer4.right_tree.right_tree.bn1.bias", "layer4.right_tree.right_tree.bn1.running_mean", "layer4.right_tree.right_tree.bn1.running_var", "layer4.right_tree.right_tree.conv2.weight", "layer4.right_tree.right_tree.bn2.weight", "layer4.right_tree.right_tree.bn2.bias", "layer4.right_tree.right_tree.bn2.running_mean", "layer4.right_tree.right_tree.bn2.running_var", "layer5.root.conv.weight", "layer5.root.bn.weight", "layer5.root.bn.bias", 
"layer5.root.bn.running_mean", "layer5.root.bn.running_var", "layer5.left_tree.root.conv.weight", "layer5.left_tree.root.bn.weight", "layer5.left_tree.root.bn.bias", "layer5.left_tree.root.bn.running_mean", "layer5.left_tree.root.bn.running_var", "layer5.left_tree.left_tree.conv1.weight", "layer5.left_tree.left_tree.bn1.weight", "layer5.left_tree.left_tree.bn1.bias", "layer5.left_tree.left_tree.bn1.running_mean", "layer5.left_tree.left_tree.bn1.running_var", "layer5.left_tree.left_tree.conv2.weight", "layer5.left_tree.left_tree.bn2.weight", "layer5.left_tree.left_tree.bn2.bias", "layer5.left_tree.left_tree.bn2.running_mean", "layer5.left_tree.left_tree.bn2.running_var", "layer5.left_tree.left_tree.shortcut.0.weight", "layer5.left_tree.left_tree.shortcut.1.weight", "layer5.left_tree.left_tree.shortcut.1.bias", "layer5.left_tree.left_tree.shortcut.1.running_mean", "layer5.left_tree.left_tree.shortcut.1.running_var", "layer5.left_tree.right_tree.conv1.weight", "layer5.left_tree.right_tree.bn1.weight", "layer5.left_tree.right_tree.bn1.bias", "layer5.left_tree.right_tree.bn1.running_mean", "layer5.left_tree.right_tree.bn1.running_var", "layer5.left_tree.right_tree.conv2.weight", "layer5.left_tree.right_tree.bn2.weight", "layer5.left_tree.right_tree.bn2.bias", "layer5.left_tree.right_tree.bn2.running_mean", "layer5.left_tree.right_tree.bn2.running_var", "layer5.right_tree.root.conv.weight", "layer5.right_tree.root.bn.weight", "layer5.right_tree.root.bn.bias", "layer5.right_tree.root.bn.running_mean", "layer5.right_tree.root.bn.running_var", "layer5.right_tree.left_tree.conv1.weight", "layer5.right_tree.left_tree.bn1.weight", "layer5.right_tree.left_tree.bn1.bias", "layer5.right_tree.left_tree.bn1.running_mean", "layer5.right_tree.left_tree.bn1.running_var", "layer5.right_tree.left_tree.conv2.weight", "layer5.right_tree.left_tree.bn2.weight", "layer5.right_tree.left_tree.bn2.bias", "layer5.right_tree.left_tree.bn2.running_mean", 
"layer5.right_tree.left_tree.bn2.running_var", "layer5.right_tree.right_tree.conv1.weight", "layer5.right_tree.right_tree.bn1.weight", "layer5.right_tree.right_tree.bn1.bias", "layer5.right_tree.right_tree.bn1.running_mean", "layer5.right_tree.right_tree.bn1.running_var", "layer5.right_tree.right_tree.conv2.weight", "layer5.right_tree.right_tree.bn2.weight", "layer5.right_tree.right_tree.bn2.bias", "layer5.right_tree.right_tree.bn2.running_mean", "layer5.right_tree.right_tree.bn2.running_var", "layer6.root.conv.weight", "layer6.root.bn.weight", "layer6.root.bn.bias", "layer6.root.bn.running_mean", "layer6.root.bn.running_var", "layer6.left_tree.conv1.weight", "layer6.left_tree.bn1.weight", "layer6.left_tree.bn1.bias", "layer6.left_tree.bn1.running_mean", "layer6.left_tree.bn1.running_var", "layer6.left_tree.conv2.weight", "layer6.left_tree.bn2.weight", "layer6.left_tree.bn2.bias", "layer6.left_tree.bn2.running_mean", "layer6.left_tree.bn2.running_var", "layer6.left_tree.shortcut.0.weight", "layer6.left_tree.shortcut.1.weight", "layer6.left_tree.shortcut.1.bias", "layer6.left_tree.shortcut.1.running_mean", "layer6.left_tree.shortcut.1.running_var", "layer6.right_tree.conv1.weight", "layer6.right_tree.bn1.weight", "layer6.right_tree.bn1.bias", "layer6.right_tree.bn1.running_mean", "layer6.right_tree.bn1.running_var", "layer6.right_tree.conv2.weight", "layer6.right_tree.bn2.weight", "layer6.right_tree.bn2.bias", "layer6.right_tree.bn2.running_mean", "layer6.right_tree.bn2.running_var", "linear.weight", "linear.bias". 
	Unexpected key(s) in state_dict: "module.base.0.weight", "module.base.1.weight", "module.base.1.bias", "module.base.1.running_mean", "module.base.1.running_var", "module.base.1

如下图所示:
在这里插入图片描述
2. 错误意思指:
表明加载模型时,传入的参数字典 state_dict 中缺失了模型需要的一些键,如 "base.0.weight"、"base.1.weight"、"base.1.bias" 等;同时又多出了一些模型不认识的键,如 "module.base.0.weight"。
3. 原因:
模型训练时使用了多张GPU并行训练,出现下面几条语句:

    # Wrap the model in DataParallel for multi-GPU training. Weights saved
    # from the wrapped model carry a "module." prefix on every key.
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True
由于 DataParallel 会把原模型包装在自己的 module 属性下,训练好后保存的模型参数键值对中,每个键的开头都多出了 "module." 字符串。

4. 解决方法:将不希望出现的键删除,将缺失的键添加进来,也就是把 state_dict 这个 {key: value} 字典中所有 key 开头的 "module." 前缀去掉。
具体代码如下:

model_cifar = SimpleDLA()
# The checkpoint file holds a dict; the trained weights are stored under 'net'.
checkpoint = torch.load("pytorch_model.pth", map_location="cpu")['net']
print("key:",checkpoint.keys())
prefix = 'module.'
# Iterate over a snapshot of the keys because the dict is mutated in the loop.
for key in list(checkpoint.keys()):
    # Only a *leading* "module." is the DataParallel wrapper prefix; testing
    # with `in` would also match the substring in the middle of a key and the
    # fixed-width slice would then chop off the wrong characters.
    if key.startswith(prefix):
        checkpoint[key[len(prefix):]] = checkpoint.pop(key)
print("key2:",checkpoint.keys())
model_cifar.load_state_dict(checkpoint)
# The following alternative also works: copy the checkpoint values into the
# model's own state_dict by position, so the key names never have to match.
# NOTE(review): entries are paired purely by iteration order and zip() stops
# at the shorter of the two key sequences -- confirm both dicts list the
# parameters in the same order before relying on this.
model = C3D_model.C3D(num_classes=101)
checkpoint = torch.load(
    'run/run_10/models/C3D-ucf101_epoch-99.pth.tar',
    map_location=lambda storage, loc: storage,
)
state_dict = model.state_dict()
state_dict.update(
    (model_key, checkpoint[ckpt_key])
    for model_key, ckpt_key in zip(list(state_dict), checkpoint)
)
model.load_state_dict(state_dict)
  1. 注意:模型训练好后保存的是一个字典,模型参数只是其中 "net" 键对应的值,保存格式如下。因此上面加载模型参数时,需要对 torch.load 返回的字典使用 ['net'] 取出模型训练好的参数:
   # Bundle the weights with the accuracy and epoch, then write them to disk
   # as a single checkpoint file. The weights sit under the 'net' key.
   state_dict = {
       'net': net.state_dict(),
       'acc': acc,
       'epoch': epoch,
   }
   torch.save(state_dict, 'pytorch_model.pth')
  1. 另外,CSDN 上对这个报错千篇一律地说把 model.load_state_dict(state_dict) 改成 model.load_state_dict(state_dict, False)(即 strict=False)。这样虽然不再报错,但所有带 "module." 前缀的键都会因为名字对不上而被直接忽略,权重实际上并没有加载进模型,完全不起作用,误人子弟。
    RuntimeError: Error(s) in loading state_dic ,Missing key(s) in state_dict , Unexpected key(s)_第1张图片
  2. 重要:最后是在 PyTorch 官方论坛上找到了答案:
    https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/13
    如下图所示
    RuntimeError: Error(s) in loading state_dic ,Missing key(s) in state_dict , Unexpected key(s)_第2张图片
  3. 另外一种解决方法:
在加载权重参数之前,初始化网络时先用 nn.DataParallel 把网络包装成并行模型,然后通过 net.module.load_state_dict 把保存下来的网络权重参数加载到被包装的原始网络上:
# `devices` is the list of all GPUs you want to use.
net = nn.DataParallel(net, device_ids=devices).to(devices[0])
# Load through `.module` (the wrapped network) so that checkpoint keys
# without the "module." prefix match.
net.module.load_state_dict(torch.load('pretrained.params'))

具体代码如下:

def load_pretrained_model(pretrained_model,num_hiddens,ffn_num_hiddens,num_heads,num_layers,dropout,max_len,devices):
    """Build a BERT model and load a pretrained vocabulary and weights into it.

    Args:
        pretrained_model: Name passed to ``d2l.torch.download_extract``
            (e.g. ``'bert.small'``); the returned directory must contain
            ``vocab.json`` and ``pretrained.params``.
        num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout, max_len:
            Forwarded unchanged to ``d2l.torch.BERTModel``.
        devices: Devices handed to ``nn.DataParallel``; the wrapped model is
            moved to ``devices[0]``.

    Returns:
        A ``(bert, vocab)`` tuple: the ``nn.DataParallel``-wrapped model with
        the pretrained weights loaded, and the vocabulary rebuilt from
        ``vocab.json``.
    """
    data_dir = d2l.torch.download_extract(pretrained_model)
    vocab = d2l.torch.Vocab()
    # Use a context manager so the vocabulary file is closed deterministically.
    with open(os.path.join(data_dir, 'vocab.json'), encoding='utf-8') as vocab_file:
        vocab.idx_to_token = json.load(vocab_file)
    vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}
    # NOTE(review): the widths below are hard-coded to 256 instead of being
    # derived from num_hiddens; presumably this only matches checkpoints
    # trained with num_hiddens=256 -- confirm before using other sizes.
    bert = d2l.torch.BERTModel(
        len(vocab),
        num_hiddens=num_hiddens,
        norm_shape=[256],
        ffn_num_input=256,
        ffn_num_hiddens=ffn_num_hiddens,
        num_heads=num_heads,
        num_layers=num_layers,
        dropout=dropout,
        max_len=max_len,
        key_size=256,
        query_size=256,
        value_size=256,
        hid_in_features=256,
        mlm_in_features=256,
        nsp_in_features=256,
    )
    bert = nn.DataParallel(bert, device_ids=devices).to(devices[0])
    # Load through `.module` (the wrapped network): the saved keys have no
    # "module." prefix, so calling load_state_dict on the DataParallel wrapper
    # itself would report missing/unexpected keys.
    bert.module.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))
    return bert, vocab
# Collect the available devices, then build the small BERT model and its
# vocabulary with the pretrained weights loaded.
devices = d2l.torch.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small',
    num_hiddens=256,
    ffn_num_hiddens=512,
    num_heads=4,
    num_layers=2,
    dropout=0.1,
    max_len=512,
    devices=devices,
)

二 相关链接

AttributeError: 'DataParallel' object has no attribute 'xxx'
Fine tuning resnet: 'DataParallel' object has no attribute 'fc'

你可能感兴趣的:(pytorch,python)