pytorch载入部分预训练权重

文章目录

  • 前言
  • 方法一
  • 方法二


前言

使用迁移学习的方法训练网络往往需要载入部分已训练好的网络权重,接下来介绍两种载入预训练权重的方法,第一种比较简单,第二种方法稍微复杂但是更加灵活。

方法一

先按原模型载入全部权重,再替换掉需要重新训练的层

	net = resnet34()  # 注意这里没有设置num_classes而是在新建全连接层的时候修改的
    # 载入预训练参数
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # 设置所有参数不可训练
    for param in net.parameters():
        param.requires_grad = False
    # 替换全连接层
    in_channel = net.fc.in_features    # 获得原模型全连接诶层输入channel
    net.fc = nn.Linear(in_channel, 5)

这里先设置所有参数为不可训练,再用新的nn.Linear替换掉原来模型中的fc(全连接层),这样新的fc就是可训练的了。

方法二

还有种情况是,我们在初始化的时候就重新设置了网络最后的输出,像这样:

net = resnet34(5)   # 设置分类的类别为5  

我们可以先读取权重到变量中但不载入,读取的参数是以一个字典的形式存储。
找到预训练权重中我们不需要的权重将它删除,再载入到网络中。

net = resnet34(5)
    net_weights = net.state_dict()    # 查看定义的网络中每一层的权重(这里只是初始化的)和对应的名称
    pre_weights = torch.load(model_weight_path, map_location=device)    # 读取预训练的权重
    del_key = []

    # 遍历预训练权重,如果存在key值包含“fc”的权重就将它从预训练权重字典中删除
    for key, _ in pre_weights.items():
        if "fc" in key:
            del_key.append(key)

    for key in del_key:
        del pre_weights[key]
    # 载入权重
    net.load_state_dict(pre_weights, strict=False)

注意这里要设置strict=False。默认为True会严格按key值载入权重,如果出现缺失的权重会报错,这里因为我们已经删除了一部分预训练权重所以要设置为False。

还有比较麻烦的情况就是定义的网络中某一层的key值和预训练权重对应层的key值不一样,这就需要手动修改字典中的key值了

你可能感兴趣的:(python,pytorch,深度学习)