Fixing the PyTorch error Missing key(s) in state_dict: "resnet.conv1.0.weight" and Unexpected key(s) in state_dict

Running predict.py fails with the following error:
RuntimeError: Error(s) in loading state_dict for VisitNet:
Missing key(s) in state_dict: "resnet.conv1.0.weight", "resnet.conv1.1.weight", "resnet.conv1.1.bias", "resnet.conv1.1.running_mean", "resnet.conv1.1.running_var", "resnet.conv1.3.weight", "resnet.conv1.4.weight", "resnet.conv1.4.bias", "resnet.conv1.4.running_mean", "resnet.conv1.4.running_var", "resnet.conv1.6.weight", …

Unexpected key(s) in state_dict: "module.resnet.conv1.0.weight", "module.resnet.conv1.1.weight", "module.resnet.conv1.1.bias", …

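The cause: during training the model was wrapped with

model = nn.DataParallel(model).cuda()

but the test script does not wrap it. nn.DataParallel stores the original network in its "module" attribute, so every key in the saved state_dict carries a "module." prefix (for example "module.resnet.conv1.0.weight"), while the unwrapped model expects the same keys without it ("resnet.conv1.0.weight"). Hence every key is reported both as missing and as unexpected.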
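To see the prefix for yourself, the minimal sketch below compares the state_dict keys of a small model before and after wrapping it in nn.DataParallel. The nn.Sequential stand-in and the names net, plain_keys and wrapped_keys are hypothetical, not part of the original VisitNet code. On a machine without a GPU, DataParallel simply holds the module, so this runs on CPU too.

```python
import torch.nn as nn

# Hypothetical stand-in model (not the post's VisitNet)
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Keys of the bare model vs. the same model wrapped in DataParallel
plain_keys = list(net.state_dict().keys())
wrapped_keys = list(nn.DataParallel(net).state_dict().keys())

print(plain_keys[:2])    # ['0.weight', '0.bias']
print(wrapped_keys[:2])  # ['module.0.weight', 'module.0.bias']
```

The only difference between the two key lists is the leading "module.", which is exactly the mismatch reported in the error above.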

Solution: add the same line to predict.py before loading the weights

model = nn.DataParallel(model).cuda()

so that the key names match. The following works:

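import torch
import torch.nn as nn

model = Net().cuda()  # Net and model_path are defined elsewhere in predict.py
for name, param in model.named_parameters():
    print(name, param.shape)  # names here have no "module." prefix
model = nn.DataParallel(model).cuda()  # wrap exactly as during training
model.load_state_dict(torch.load(model_path))  # keys now match "module.*"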
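If you would rather not wrap the model at inference time (for example on a single-GPU or CPU-only machine), a commonly used alternative is to strip the "module." prefix from the checkpoint keys before loading them into the plain model. The sketch below assumes the checkpoint is a plain state_dict as in the post; the helper strip_module_prefix and the demo names net, saved and fresh are hypothetical.

```python
import torch.nn as nn

def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Demo with a hypothetical stand-in model: save from a wrapped model,
# then load into a fresh, unwrapped copy of the same architecture.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
saved = nn.DataParallel(net).state_dict()  # keys look like "module.0.weight"

fresh = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
result = fresh.load_state_dict(strip_module_prefix(saved))  # strict loading succeeds
print(result.missing_keys, result.unexpected_keys)
```

In predict.py this would become model.load_state_dict(strip_module_prefix(torch.load(model_path))) on the unwrapped model. Another option is to save model.module.state_dict() during training, so the checkpoint never contains the prefix.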
