学习笔记26-解决:载入预训练模型时Pytorch遇到权重不匹配的问题(附+修改后的预训练模型载入和冻结特征权重完整代码)

在pytorch微调mobilenetV3模型时遇到的问题
1.KeyError: ‘features.4.block.2.fc1.weight’
这个是因为模型结构修改了,没有正确修改预训练权重,导致载入权重与模型不同,使用下面说的两种方法适当修改载入权重即可。
2.size mismatch for fc.weight: copying a param with shape torch.Size([1000, 1280]) from checkpoint, the shape in current model is torch.Size([4, 1280]).
size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([4]).
下载的预训练模型中的全连接层是1000类别的,而本人的类别只有4类,所以会报不匹配的错误。

因此我就学习了PyTorch载入预训练权重方法
方法一 :自称为万能法,直接删掉分类层,直接避免最后全连接层权重不匹配问题。

net = MobileNetV3(num_class=4)
net_weights = net.state_dict()
model_weights_path = "model/mobilenet_v3.pth"
pre_weights = torch.load(model_weights_path)
# delete classifier weights
# 这种方法主要是遍历字典,.pth文件(权重文件)的本质就是字典的存储
# 通过改变我们载入的权重的键值对,可以和当前的网络进行配对的
# 这里举到的例子是对"classifier"结构层的键值对剔除,或者说是不载入该模块的训练权重,这里的"classifier"结构层就是最后一部分分类层
pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
# 如果修改了载入权重或载入权重的结构和当前模型的结构不完全相同,需要加strict=False,保证能够权重载入
net.load_state_dict(pre_dict, strict=False)
net.to(device)

方法二 :进行两种权重对比,会减少问题存在,但有的时候还是会出现问题。

# 另一种方法会直接两种权重对比,直接两种方法对比,减少问题的存在
net = MobileNetV3(num_class=4)
net_weights = net.state_dict()
model_weights_path = "model/mobilenet_v3.pth"
pre_weights = torch.load(model_weights_path)
# 通过改变我们载入的权重的键值对,可以和当前的网络进行配对的
pre_dict = {k: v for k, v in pre_weight.items() 
			if net_weights[k].numel() == v.numel()}
# 在下载的官方预训练参数中,num_classes=1000 而在我们的model中num_classes=4
#net.state_dict()[k].numel(): 提取model模型中的关键字K代表的层的长度
#v.numel: 是下载的预训练参数中对应层的长度
# 如果修改了载入权重或载入权重的结构和当前模型的结构不完全相同,需要加strict=False,保证能够权重载入
net.load_state_dict(pre_dict, strict=False)
net.to(device)

利用方法一解决自己权重不匹配问题,最后载入预训练模型和冻结特征权重完整代码如下所示:

# create model
    net = mobilenet_v6_large(num_classes=4)
    # load pretrain weights 加载一个预训练模型,收敛很快
    model_weight_path = "model/mobilenet_v3.pth" #后期可以使用自己数据集的训练权重作为预训练模型,当为老师,实现半监督(知识蒸馏)
    #断言判断pth文件在不在
    assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)

    pre_weights = torch.load(model_weight_path, map_location=device)
    # 这里是对"classifier"结构层的键值对剔除,简单理解就是不载入该模块的训练权重
    pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
    
    # freeze features weights  #冻结特征权重
    for param in net.features.parameters():
        # param.requires_grad = True  #一起训练
        param.requires_grad = False   #迁移学习,固定他的特征提取层,优化他的全连接分类层

    net.to(device)# 可以指定CPU或者GPU,和cuda()区别是cuda()只能指定GPU

你可能感兴趣的:(学习笔记,pytorch,学习,深度学习)