Key Mismatch When Loading Pretrained Weights

Scenario

While reproducing the paper Rethinking the Learning Paradigm for Dynamic Facial Expression Recognition, I loaded an already-trained .pth checkpoint for inference and found that the accuracy was very low. The weights were loaded with the following two lines:

weights_dict = torch.load('/data2/liuxu/attribute/M3DFEL/outputs/DFEW-[10-29]-[14:29]/model_best.pth', map_location='cuda:7')
mymodel.load_state_dict(weights_dict, strict=False)  # strict=False tolerates keys that don't match, but it only skips them; it does not remap anything
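The pitfall here is that strict=False merely silences the error: unmatched parameters keep their random initialization, which is exactly why inference accuracy collapses. A minimal sketch (toy Net model, hypothetical names) reproducing the same situation:

```python
import torch
import torch.nn as nn

# Minimal model with two learnable tensors
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

net = Net()

# A "checkpoint" shaped like the one in this post: the real weights are
# nested under 'state_dict', with training bookkeeping fields beside them
ckpt = {"epoch": 5, "state_dict": net.state_dict()}

# strict=False does NOT remap keys; it only suppresses the exception.
# Every model parameter keeps its initialization, and the returned
# object records exactly what failed to match.
result = net.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['fc.weight', 'fc.bias']
print(result.unexpected_keys)  # ['epoch', 'state_dict']
```

When every parameter of the model shows up in missing_keys, no weights were loaded at all, and the model is effectively running with random weights.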

load_state_dict returned the following:

_IncompatibleKeys(missing_keys=['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.conv1.0.weight', 'features.1.0.conv1.1.weight', 'features.1.0.conv1.1.bias', 'features.1.0.conv1.1.running_mean', 'features.1.0.conv1.1.running_var', 'features.1.0.conv2.0.weight', 'features.1.0.conv2.1.weight', 'features.1.0.conv2.1.bias', 'features.1.0.conv2.1.running_mean', 'features.1.0.conv2.1.running_var', 'features.1.1.conv1.0.weight', 'features.1.1.conv1.1.weight', 'features.1.1.conv1.1.bias', 'features.1.1.conv1.1.running_mean', 'features.1.1.conv1.1.running_var', 'features.1.1.conv2.0.weight', 'features.1.1.conv2.1.weight', 'features.1.1.conv2.1.bias', 'features.1.1.conv2.1.running_mean', 'features.1.1.conv2.1.running_var', 'features.2.0.conv1.0.weight', 'features.2.0.conv1.1.weight', 'features.2.0.conv1.1.bias', 'features.2.0.conv1.1.running_mean', 'features.2.0.conv1.1.running_var', 'features.2.0.conv2.0.weight', 'features.2.0.conv2.1.weight', 'features.2.0.conv2.1.bias', 'features.2.0.conv2.1.running_mean', 'features.2.0.conv2.1.running_var', 'features.2.0.downsample.0.weight', 'features.2.0.downsample.1.weight', 'features.2.0.downsample.1.bias', 'features.2.0.downsample.1.running_mean', 'features.2.0.downsample.1.running_var', 'features.2.1.conv1.0.weight', 'features.2.1.conv1.1.weight', 'features.2.1.conv1.1.bias', 'features.2.1.conv1.1.running_mean', 'features.2.1.conv1.1.running_var', 'features.2.1.conv2.0.weight', 'features.2.1.conv2.1.weight', 'features.2.1.conv2.1.bias', 'features.2.1.conv2.1.running_mean', 'features.2.1.conv2.1.running_var', 'features.3.0.conv1.0.weight', 'features.3.0.conv1.1.weight', 'features.3.0.conv1.1.bias', 'features.3.0.conv1.1.running_mean', 'features.3.0.conv1.1.running_var', 'features.3.0.conv2.0.weight', 'features.3.0.conv2.1.weight', 'features.3.0.conv2.1.bias', 'features.3.0.conv2.1.running_mean', 'features.3.0.conv2.1.running_var',
'features.3.0.downsample.0.weight', 'features.3.0.downsample.1.weight', 'features.3.0.downsample.1.bias', 'features.3.0.downsample.1.running_mean', 'features.3.0.downsample.1.running_var', 'features.3.1.conv1.0.weight', 'features.3.1.conv1.1.weight', 'features.3.1.conv1.1.bias', 'features.3.1.conv1.1.running_mean', 'features.3.1.conv1.1.running_var', 'features.3.1.conv2.0.weight', 'features.3.1.conv2.1.weight', 'features.3.1.conv2.1.bias', 'features.3.1.conv2.1.running_mean', 'features.3.1.conv2.1.running_var', 'features.4.0.conv1.0.weight', 'features.4.0.conv1.1.weight', 'features.4.0.conv1.1.bias', 'features.4.0.conv1.1.running_mean', 'features.4.0.conv1.1.running_var', 'features.4.0.conv2.0.weight', 'features.4.0.conv2.1.weight', 'features.4.0.conv2.1.bias', 'features.4.0.conv2.1.running_mean', 'features.4.0.conv2.1.running_var', 'features.4.0.downsample.0.weight', 'features.4.0.downsample.1.weight', 'features.4.0.downsample.1.bias', 'features.4.0.downsample.1.running_mean', 'features.4.0.downsample.1.running_var', 'features.4.1.conv1.0.weight', 'features.4.1.conv1.1.weight', 'features.4.1.conv1.1.bias', 'features.4.1.conv1.1.running_mean', 'features.4.1.conv1.1.running_var', 'features.4.1.conv2.0.weight', 'features.4.1.conv2.1.weight', 'features.4.1.conv2.1.bias', 'features.4.1.conv2.1.running_mean', 'features.4.1.conv2.1.running_var', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'lstm.weight_ih_l0_reverse', 'lstm.weight_hh_l0_reverse', 'lstm.bias_ih_l0_reverse', 'lstm.bias_hh_l0_reverse', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1', 'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'lstm.weight_ih_l1_reverse', 'lstm.weight_hh_l1_reverse', 'lstm.bias_ih_l1_reverse', 'lstm.bias_hh_l1_reverse', 'to_qkv.weight', 'norm.weight', 'norm.bias', 'norm.mean_weight', 'norm.var_weight', 'pwconv.weight', 'pwconv.bias', 'fc.weight', 'fc.bias'], unexpected_keys=['epoch', 'state_dict', 'best_wa', 'best_ua', 'optimizer', 'args'])

Root Cause and Solution

Inspecting the keys of the loaded dictionary shows that none of its top-level keys are model parameters:

weights_dict.keys()

dict_keys(['epoch', 'state_dict', 'best_wa', 'best_ua', 'optimizer', 'args'])

In other words, the file is a full training checkpoint rather than a bare state_dict: the model weights are nested under the 'state_dict' key, and the remaining entries hold the epoch counter, metrics, and optimizer state. Extracting weights_dict["state_dict"] gives the sub-dictionary that actually matches the model, so the corrected code is:

weights_dict = torch.load('/data2/liuxu/attribute/M3DFEL/outputs/DFEW-[10-29]-[14:29]/model_best.pth', map_location='cuda:7')
mymodel.load_state_dict(weights_dict["state_dict"], strict=False)
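This pattern generalizes. A small helper (hypothetical name, a sketch rather than part of the paper's code) can accept either layout, and also strip the "module." prefix that checkpoints saved from an nn.DataParallel-wrapped model commonly carry:

```python
def extract_state_dict(ckpt, prefix="module."):
    """Return the bare parameter dict, whether `ckpt` is a full training
    checkpoint or already a state_dict, stripping an optional wrapper
    prefix such as the 'module.' left by nn.DataParallel."""
    # Unwrap a training checkpoint that nests the weights under 'state_dict'
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    # Drop the wrapper prefix from each key, if present
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in ckpt.items()}

# Demo with placeholder values standing in for tensors
fake = {"epoch": 3, "state_dict": {"module.fc.weight": 0, "module.fc.bias": 0}}
print(sorted(extract_state_dict(fake)))  # ['fc.bias', 'fc.weight']
```

Once the keys match, it is safer to load with the default strict=True, so any remaining mismatch fails loudly instead of silently skipping parameters.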

With this change, the pretrained model runs inference correctly:

device = "cuda:7"
mymodel.to(device)

test_dataloader = create_dataloader(args, "test")

mymodel.eval()

all_pred, all_target = [], []
# Run the model on the test set
for i, (images, target) in enumerate(test_dataloader):
    images = images.to(device)
    target = target.to(device)
    with torch.no_grad():
        output = mymodel(images)
    pred = torch.argmax(output, 1).cpu().detach().numpy()
    target = target.cpu().numpy()
    print(pred, target)
    break  # inspect only the first batch

[1 2 1 2 3 1 2 2 4 4 3 6 1 2 3 6] [1 2 1 3 3 1 6 4 4 2 3 6 0 4 3 0]
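The predictions now track the labels reasonably well. As a quick sanity check on that single batch (note this is plain per-sample accuracy on 16 samples, not the weighted/unweighted average recall the paper reports over the full test set):

```python
import numpy as np

# Predictions and labels from the single batch printed above
pred   = np.array([1, 2, 1, 2, 3, 1, 2, 2, 4, 4, 3, 6, 1, 2, 3, 6])
target = np.array([1, 2, 1, 3, 3, 1, 6, 4, 4, 2, 3, 6, 0, 4, 3, 0])

# Fraction of samples where the predicted class matches the label
acc = (pred == target).mean()
print(acc)  # 0.5625 -> 9 of 16 samples correct in this batch
```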
