RuntimeError: Error(s) in loading state_dict for SimpleDLA:
Missing key(s) in state_dict: "base.0.weight", "base.1.weight", "base.1.bias", "base.1.running_mean", "base.1.running_var", "layer1.0.weight", "layer1.1.weight", "layer1.1.bias", "layer1.1.running_mean", "layer1.1.running_var", "layer2.0.weight", "layer2.1.weight", "layer2.1.bias", "layer2.1.running_mean", "layer2.1.running_var", "layer3.root.conv.weight", "layer3.root.bn.weight", "layer3.root.bn.bias", "layer3.root.bn.running_mean", "layer3.root.bn.running_var", "layer3.left_tree.conv1.weight", "layer3.left_tree.bn1.weight", "layer3.left_tree.bn1.bias", "layer3.left_tree.bn1.running_mean", "layer3.left_tree.bn1.running_var", "layer3.left_tree.conv2.weight", "layer3.left_tree.bn2.weight", "layer3.left_tree.bn2.bias", "layer3.left_tree.bn2.running_mean", "layer3.left_tree.bn2.running_var", "layer3.left_tree.shortcut.0.weight", "layer3.left_tree.shortcut.1.weight", "layer3.left_tree.shortcut.1.bias", "layer3.left_tree.shortcut.1.running_mean", "layer3.left_tree.shortcut.1.running_var", "layer3.right_tree.conv1.weight", "layer3.right_tree.bn1.weight", "layer3.right_tree.bn1.bias", "layer3.right_tree.bn1.running_mean", "layer3.right_tree.bn1.running_var", "layer3.right_tree.conv2.weight", "layer3.right_tree.bn2.weight", "layer3.right_tree.bn2.bias", "layer3.right_tree.bn2.running_mean", "layer3.right_tree.bn2.running_var", "layer4.root.conv.weight", "layer4.root.bn.weight", "layer4.root.bn.bias", "layer4.root.bn.running_mean", "layer4.root.bn.running_var", "layer4.left_tree.root.conv.weight", "layer4.left_tree.root.bn.weight", "layer4.left_tree.root.bn.bias", "layer4.left_tree.root.bn.running_mean", "layer4.left_tree.root.bn.running_var", "layer4.left_tree.left_tree.conv1.weight", "layer4.left_tree.left_tree.bn1.weight", "layer4.left_tree.left_tree.bn1.bias", "layer4.left_tree.left_tree.bn1.running_mean", "layer4.left_tree.left_tree.bn1.running_var", "layer4.left_tree.left_tree.conv2.weight", "layer4.left_tree.left_tree.bn2.weight", "layer4.left_tree.left_tree.bn2.bias", "layer4.left_tree.left_tree.bn2.running_mean", "layer4.left_tree.left_tree.bn2.running_var", "layer4.left_tree.left_tree.shortcut.0.weight", "layer4.left_tree.left_tree.shortcut.1.weight", "layer4.left_tree.left_tree.shortcut.1.bias", "layer4.left_tree.left_tree.shortcut.1.running_mean", "layer4.left_tree.left_tree.shortcut.1.running_var", "layer4.left_tree.right_tree.conv1.weight", "layer4.left_tree.right_tree.bn1.weight", "layer4.left_tree.right_tree.bn1.bias", "layer4.left_tree.right_tree.bn1.running_mean", "layer4.left_tree.right_tree.bn1.running_var", "layer4.left_tree.right_tree.conv2.weight", "layer4.left_tree.right_tree.bn2.weight", "layer4.left_tree.right_tree.bn2.bias", "layer4.left_tree.right_tree.bn2.running_mean", "layer4.left_tree.right_tree.bn2.running_var", "layer4.right_tree.root.conv.weight", "layer4.right_tree.root.bn.weight", "layer4.right_tree.root.bn.bias", "layer4.right_tree.root.bn.running_mean", "layer4.right_tree.root.bn.running_var", "layer4.right_tree.left_tree.conv1.weight", "layer4.right_tree.left_tree.bn1.weight", "layer4.right_tree.left_tree.bn1.bias", "layer4.right_tree.left_tree.bn1.running_mean", "layer4.right_tree.left_tree.bn1.running_var", "layer4.right_tree.left_tree.conv2.weight", "layer4.right_tree.left_tree.bn2.weight", "layer4.right_tree.left_tree.bn2.bias", "layer4.right_tree.left_tree.bn2.running_mean", "layer4.right_tree.left_tree.bn2.running_var", "layer4.right_tree.right_tree.conv1.weight", "layer4.right_tree.right_tree.bn1.weight", "layer4.right_tree.right_tree.bn1.bias", 
"layer4.right_tree.right_tree.bn1.running_mean", "layer4.right_tree.right_tree.bn1.running_var", "layer4.right_tree.right_tree.conv2.weight", "layer4.right_tree.right_tree.bn2.weight", "layer4.right_tree.right_tree.bn2.bias", "layer4.right_tree.right_tree.bn2.running_mean", "layer4.right_tree.right_tree.bn2.running_var", "layer5.root.conv.weight", "layer5.root.bn.weight", "layer5.root.bn.bias", "layer5.root.bn.running_mean", "layer5.root.bn.running_var", "layer5.left_tree.root.conv.weight", "layer5.left_tree.root.bn.weight", "layer5.left_tree.root.bn.bias", "layer5.left_tree.root.bn.running_mean", "layer5.left_tree.root.bn.running_var", "layer5.left_tree.left_tree.conv1.weight", "layer5.left_tree.left_tree.bn1.weight", "layer5.left_tree.left_tree.bn1.bias", "layer5.left_tree.left_tree.bn1.running_mean", "layer5.left_tree.left_tree.bn1.running_var", "layer5.left_tree.left_tree.conv2.weight", "layer5.left_tree.left_tree.bn2.weight", "layer5.left_tree.left_tree.bn2.bias", "layer5.left_tree.left_tree.bn2.running_mean", "layer5.left_tree.left_tree.bn2.running_var", "layer5.left_tree.left_tree.shortcut.0.weight", "layer5.left_tree.left_tree.shortcut.1.weight", "layer5.left_tree.left_tree.shortcut.1.bias", "layer5.left_tree.left_tree.shortcut.1.running_mean", "layer5.left_tree.left_tree.shortcut.1.running_var", "layer5.left_tree.right_tree.conv1.weight", "layer5.left_tree.right_tree.bn1.weight", "layer5.left_tree.right_tree.bn1.bias", "layer5.left_tree.right_tree.bn1.running_mean", "layer5.left_tree.right_tree.bn1.running_var", "layer5.left_tree.right_tree.conv2.weight", "layer5.left_tree.right_tree.bn2.weight", "layer5.left_tree.right_tree.bn2.bias", "layer5.left_tree.right_tree.bn2.running_mean", "layer5.left_tree.right_tree.bn2.running_var", "layer5.right_tree.root.conv.weight", "layer5.right_tree.root.bn.weight", "layer5.right_tree.root.bn.bias", "layer5.right_tree.root.bn.running_mean", "layer5.right_tree.root.bn.running_var", "layer5.right_tree.left_tree.conv1.weight", "layer5.right_tree.left_tree.bn1.weight", "layer5.right_tree.left_tree.bn1.bias", "layer5.right_tree.left_tree.bn1.running_mean", "layer5.right_tree.left_tree.bn1.running_var", "layer5.right_tree.left_tree.conv2.weight", "layer5.right_tree.left_tree.bn2.weight", "layer5.right_tree.left_tree.bn2.bias", "layer5.right_tree.left_tree.bn2.running_mean", "layer5.right_tree.left_tree.bn2.running_var", "layer5.right_tree.right_tree.conv1.weight", "layer5.right_tree.right_tree.bn1.weight", "layer5.right_tree.right_tree.bn1.bias", "layer5.right_tree.right_tree.bn1.running_mean", "layer5.right_tree.right_tree.bn1.running_var", "layer5.right_tree.right_tree.conv2.weight", "layer5.right_tree.right_tree.bn2.weight", "layer5.right_tree.right_tree.bn2.bias", "layer5.right_tree.right_tree.bn2.running_mean", "layer5.right_tree.right_tree.bn2.running_var", "layer6.root.conv.weight", "layer6.root.bn.weight", "layer6.root.bn.bias", "layer6.root.bn.running_mean", "layer6.root.bn.running_var", "layer6.left_tree.conv1.weight", "layer6.left_tree.bn1.weight", "layer6.left_tree.bn1.bias", "layer6.left_tree.bn1.running_mean", "layer6.left_tree.bn1.running_var", "layer6.left_tree.conv2.weight", "layer6.left_tree.bn2.weight", "layer6.left_tree.bn2.bias", "layer6.left_tree.bn2.running_mean", "layer6.left_tree.bn2.running_var", "layer6.left_tree.shortcut.0.weight", "layer6.left_tree.shortcut.1.weight", "layer6.left_tree.shortcut.1.bias", "layer6.left_tree.shortcut.1.running_mean", "layer6.left_tree.shortcut.1.running_var", "layer6.right_tree.conv1.weight", 
"layer6.right_tree.bn1.weight", "layer6.right_tree.bn1.bias", "layer6.right_tree.bn1.running_mean", "layer6.right_tree.bn1.running_var", "layer6.right_tree.conv2.weight", "layer6.right_tree.bn2.weight", "layer6.right_tree.bn2.bias", "layer6.right_tree.bn2.running_mean", "layer6.right_tree.bn2.running_var", "linear.weight", "linear.bias".
Unexpected key(s) in state_dict: "module.base.0.weight", "module.base.1.weight", "module.base.1.bias", "module.base.1.running_mean", "module.base.1.running_var", "module.base.1
2. What the error means:
When loading the model, the state_dict is missing some keys, such as "base.0.weight", "base.1.weight" and "base.1.bias", and contains some unexpected keys, such as "module.base.0.weight". In other words, the keys stored in the checkpoint do not match the parameter names of the freshly built model, as the sketch below shows.
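To see the mismatch directly, compare the model's own parameter names with the keys stored in the checkpoint before calling load_state_dict (a minimal sketch, assuming the SimpleDLA class and the pytorch_model.pth checkpoint used later in this post):

import torch

model = SimpleDLA()  # freshly built model, no DataParallel wrapper
checkpoint = torch.load("pytorch_model.pth", map_location="cpu")['net']
print(list(model.state_dict().keys())[:2])   # ['base.0.weight', 'base.1.weight']
print(list(checkpoint.keys())[:2])           # ['module.base.0.weight', 'module.base.1.weight']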
3. Cause:
The model was trained on multiple GPUs in parallel, i.e. the training script contained statements such as:
model = torch.nn.DataParallel(model)
cudnn.benchmark = True
As a result, every key in the saved state_dict gained a "module." prefix, which the plain (unwrapped) model does not expect. The short sketch below reproduces this prefix.
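Wrapping any module in DataParallel renames every entry in its state_dict (a minimal sketch with a throwaway two-layer model, not the author's network):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
print(list(net.state_dict().keys())[:2])
# ['0.weight', '0.bias']
net = nn.DataParallel(net)
print(list(net.state_dict().keys())[:2])
# ['module.0.weight', 'module.0.bias']  -- every key now starts with "module."

Saving net.state_dict() at this point bakes the prefix into the checkpoint.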
4. Solution: remove the unexpected keys and supply the missing ones, which here simply means stripping the "module." prefix from every key of the dict before loading it.
The code is as follows:
import torch
# SimpleDLA is the model class used during training (the import path depends on your project)
from models import SimpleDLA

model_cifar = SimpleDLA()
# the training script saved the weights under the 'net' key (see the save code further down)
checkpoint = torch.load("pytorch_model.pth", map_location="cpu")['net']
print("keys before:", checkpoint.keys())
for key in list(checkpoint.keys()):
    if key.startswith('module.'):
        checkpoint[key[len('module.'):]] = checkpoint.pop(key)  # strip the "module." prefix
print("keys after:", checkpoint.keys())
model_cifar.load_state_dict(checkpoint)
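Equivalently, the prefix can be stripped in a single dict comprehension; recent PyTorch releases also ship a helper, consume_prefix_in_state_dict_if_present in torch.nn.modules.utils, that does the same thing in place. Both are shown below as a sketch on the checkpoint dict from above:

# one-line equivalent of the loop above
checkpoint = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in checkpoint.items()}

# or, with the built-in helper (available in recent PyTorch versions)
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
consume_prefix_in_state_dict_if_present(checkpoint, 'module.')

model_cifar.load_state_dict(checkpoint)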
# The following code also works: it copies the weights over by position,
# relying on both dicts listing their keys in the same order (here for a C3D model):
model = C3D_model.C3D(num_classes=101)
checkpoint = torch.load('run/run_10/models/C3D-ucf101_epoch-99.pth.tar',
                        map_location=lambda storage, loc: storage)
state_dict = model.state_dict()
for k1, k2 in zip(state_dict.keys(), checkpoint.keys()):
    state_dict[k1] = checkpoint[k2]  # copy by position, ignoring the name mismatch
model.load_state_dict(state_dict)
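The positional copy above silently assumes that the two dicts have the same number of entries, in the same order. A quick check before copying makes that assumption explicit (a sketch on the state_dict and checkpoint variables from the snippet above):

assert len(state_dict) == len(checkpoint), "the two dicts have different numbers of entries"
for (k1, v1), (k2, v2) in zip(state_dict.items(), checkpoint.items()):
    assert v1.shape == v2.shape, f"shape mismatch between {k1} and {k2}"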
# For reference, this is how the checkpoint above was saved during training,
# which is why the weights sit under checkpoint['net']:
state_dict = {
    'net': net.state_dict(),
    'acc': acc,
    'epoch': epoch,
}
torch.save(state_dict, 'pytorch_model.pth')
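A way to avoid the mismatch at the source is to save the parameters of the inner module rather than of the DataParallel container, so the keys never gain the "module." prefix. A sketch of how the save call above could be adjusted (an assumption about the training script, not the author's original code):

state_dict = {
    # unwrap DataParallel before saving so the keys carry no "module." prefix
    'net': net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict(),
    'acc': acc,
    'epoch': epoch,
}
torch.save(state_dict, 'pytorch_model.pth')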
Alternatively, wrap the network in DataParallel when it is initialized, before loading any weights, and then load the saved parameters through net.module.load_state_dict:
net = nn.DataParallel(net, device_ids=devices).to(devices[0])  # devices = all of your GPUs
net.module.load_state_dict(torch.load('pretrained.params'))
The full code is as follows:
import json
import os
import torch
from torch import nn
import d2l.torch

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.torch.download_extract(pretrained_model)
    # rebuild the vocabulary that was saved next to the pretrained weights
    vocab = d2l.torch.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}
    bert = d2l.torch.BERTModel(len(vocab), num_hiddens=num_hiddens, norm_shape=[256],
        ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens, num_heads=num_heads,
        num_layers=num_layers, dropout=dropout, max_len=max_len, key_size=256,
        query_size=256, value_size=256, hid_in_features=256, mlm_in_features=256,
        nsp_in_features=256)
    # wrap with DataParallel first, then load the un-prefixed weights through .module
    bert = nn.DataParallel(bert, device_ids=devices).to(devices[0])
    bert.module.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))
    # bert.load_state_dict(...) would fail here: it expects "module."-prefixed keys
    return bert, vocab

devices = d2l.torch.try_all_gpus()
bert, vocab = load_pretrained_model('bert.small', num_hiddens=256, ffn_num_hiddens=512,
                                    num_heads=4, num_layers=2, dropout=0.1, max_len=512, devices=devices)
A closely related error is AttributeError: 'DataParallel' object has no attribute 'xxx', for example when fine-tuning a ResNet: 'DataParallel' object has no attribute 'fc'. The wrapper only exposes the attribute module, so everything the original model defined has to be reached through it, as sketched below.
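A minimal sketch, using torchvision's resnet18 purely for illustration:

import torch.nn as nn
import torchvision

model = nn.DataParallel(torchvision.models.resnet18())
# model.fc  -> AttributeError: 'DataParallel' object has no attribute 'fc'
model.module.fc = nn.Linear(model.module.fc.in_features, 10)  # reach the head through .module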