Convolutional Neural Network 的 PyTorch 实现(一)指定部分预训练权重加载


前言

迁移学习的方法被广泛应用于卷积神经网络,基于大数据集训练而得到的权重文件对数据具有强的特征提取能力,在此基础上针对特有数据集进行模型的二次训练(微调),能大大降低训练时长以及犯错成本。

当改变卷积神经网络模型结构后,原有的预训练权重将无法成功加载到已经改变了的模型中,以下提供了针对模型的修改,实现指定部分权重的加载。


代码如下:


    # load pretrain weights
    model_weight_path = "./resnet34_pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # Official_Option
	# net = resnet34()
    # net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    # in_channel = net.fc.in_features
    # net.fc = nn.Linear(in_channel, 5)

    # The Other option
    net = resnet34(num_classes=100)
    net_State = net.state_dict()
    pre_weights = torch.load(model_weight_path, map_location=device)
    del_key = []
    
    for key, _ in pre_weights.items():
        if "fc" in key or "layer4" in key:
            del_key.append(key)
    for key in del_key:
        del pre_weights[key]
        
    missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    print("[missing_keys]:", *missing_keys, sep="\n")
    print("[unexpected_keys]:", *unexpected_keys, sep="\n")

你可能感兴趣的:(笔记,pytorch,深度学习)