RuntimeError: Error(s) in loading state_dict for DataParallel: size mismatch for module.fcc.weight:

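Problem description:

Loading a pretrained model in PyTorch fails because the pretrained classifier was trained for a different number of classes than the current model.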

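Error message: the RuntimeError quoted in the title, a size mismatch for module.fcc.weight (the classifier weights have one shape in the checkpoint and another in the current model).

Code that raises the error: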

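import torch
import torch.nn as nn

checkpoint = torch.load('pretrain.pth', map_location=device)
model = nn.DataParallel(model)  # DataParallel prefixes every key with 'module.'
# Still raises RuntimeError: strict=False does not skip parameters whose shapes differ
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.to(device)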
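To see why strict=False does not help, the following minimal sketch reproduces the error on CPU with a toy model. The class ToyNet, its layer sizes, and the class counts 10 and 5 are hypothetical stand-ins for the real network. strict=False only tolerates missing or unexpected keys; a key present on both sides with different shapes still raises a RuntimeError.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in: a BatchNorm layer plus an 'fcc' classifier."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.bn1 = nn.BatchNorm1d(8)
        self.fcc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fcc(self.bn1(x))


# "Pretrained" weights for 10 classes, taken from a DataParallel wrapper
pretrained = nn.DataParallel(ToyNet(num_classes=10))
state_dict = pretrained.state_dict()
print(list(state_dict)[:2])  # every key carries the 'module.' prefix

# New task with 5 classes: same key names, different fcc shapes
model = nn.DataParallel(ToyNet(num_classes=5))
error_message = ""
try:
    model.load_state_dict(state_dict, strict=False)
except RuntimeError as err:
    # strict=False does not cover shape mismatches, so this branch runs
    error_message = str(err)
print(error_message)
```

The printed message names module.fcc.weight and module.fcc.bias, which is exactly the situation in the title.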

Attempt: delete the mismatched classifier keys (fcc and head) from the checkpoint before loading:

checkpoint = torch.load('pretrain.pth', map_location=device)

del_keys = ['module.fcc.weight', 'module.fcc.bias', 'module.head.weight', 'module.head.bias']
for k in del_keys:
    # Fails with KeyError: these keys live in checkpoint['model_state_dict'], not in checkpoint
    del checkpoint[k]

model = nn.DataParallel(model)
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.to(device)

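Error: KeyError: 'module.fcc.weight'

Debugging showed that the top-level checkpoint only contains 'iter' and 'model_state_dict'. The pretrained weights, and therefore the keys that have to be deleted, are in checkpoint['model_state_dict']: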

In[3]: checkpoint.keys()
Out[3]: dict_keys(['iter', 'model_state_dict'])

In[4]: checkpoint['model_state_dict'].keys()
Out[4]: odict_keys(['module.bn1.weight', 'module.bn1.bias', ......, 'module.head.weight', 'module.head.bias', ......])
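The KeyError in the attempt comes from deleting at the wrong level of this nested dict. The sketch below rebuilds a checkpoint with the same two top-level keys, 'iter' and 'model_state_dict', using a hypothetical TinyNet, a made-up iteration count, and an in-memory buffer instead of pretrain.pth, and shows where the layer names actually live.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model with layers named like the ones in the post."""

    def __init__(self) -> None:
        super().__init__()
        self.fcc = nn.Linear(4, 3)
        self.head = nn.Linear(3, 2)


net = nn.DataParallel(TinyNet())

# Same layout as the checkpoint in the post; the 'iter' value is made up
buffer = io.BytesIO()
torch.save({'iter': 1000, 'model_state_dict': net.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location='cpu')
print(checkpoint.keys())  # only 'iter' and 'model_state_dict'

# Deleting at the top level fails: the layer names are one level down
try:
    del checkpoint['module.fcc.weight']
    top_level_delete_failed = False
except KeyError:
    top_level_delete_failed = True

# Deleting inside the nested state dict works
del checkpoint['model_state_dict']['module.fcc.weight']
print(list(checkpoint['model_state_dict']))
```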

The solution is therefore to delete the classifier keys from checkpoint['model_state_dict']:

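import torch
import torch.nn as nn

# model (built for the new number of classes) and device are defined earlier
checkpoint = torch.load('pretrain.pth', map_location=device)
state_dict = checkpoint['model_state_dict']  # the pretrained weights live one level down

# Collect the classifier keys first, then delete them; deleting while
# iterating over the dict would raise "dictionary changed size during iteration"
del_keys = [key for key in state_dict if 'fcc' in key or 'head' in key]
for key in del_keys:
    del state_dict[key]

model = nn.DataParallel(model)
# The fcc/head keys are now missing, which strict=False tolerates; those layers keep their fresh initialisation
model.load_state_dict(state_dict, strict=False)
model.to(device)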
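Matching on the substrings 'fcc' and 'head' works for this network but would also drop any other key that happens to contain those words. As an alternative (not the approach used above), a minimal sketch can keep only the checkpoint tensors whose name and shape both match the current model, then use the value returned by load_state_dict to confirm which keys were skipped. The helper filter_matching_shapes and the toy TinyNet are hypothetical.

```python
import torch
import torch.nn as nn


def filter_matching_shapes(model: nn.Module, pretrained: dict) -> dict:
    """Keep only pretrained tensors whose key exists in model with the same shape."""
    own = model.state_dict()
    return {
        key: value
        for key, value in pretrained.items()
        if key in own and value.shape == own[key].shape
    }


class TinyNet(nn.Module):
    """Hypothetical backbone + classifier; only 'head' depends on num_classes."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, num_classes)


pretrained = nn.DataParallel(TinyNet(num_classes=10)).state_dict()
model = nn.DataParallel(TinyNet(num_classes=5))

result = model.load_state_dict(filter_matching_shapes(model, pretrained), strict=False)
# The skipped classifier tensors are reported as missing keys
print(result.missing_keys)
```

Checking result.missing_keys after loading is a cheap way to make sure that only the classifier layers were left uninitialised.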
