PyTorch 加载预训练权重

前言

 使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习。

 在大部分的迁移学习场景中,我们一般沿用之前模型的相关参数,这是因为卷积神经网络认为大部分的特征提取模式是一致的,即卷积神经网络中的归纳偏置能力强。在使用别人训练好的权重的过程中,一般冻结/保留提供权重模型中浅层的权重参数,只修改跟当前任务相关的层数。

 在本文中,主要讲解如何修改提供的权重并将其迁移到当前的任务上,例如:如何将PyTorch官方提供的ResNet权重迁移到别的分类任务上。本文将围绕此需求进行相关方法的介绍。

一、PyTorch官方ResNet权重下载链接

 PyTorch提供的ResNet权重文件下载链接如下:

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

本文通过ResNet网络使用num_classes代表该任务的分类类别,例如要将PyTorch官方在ImageNet数据集上训练的参数迁移到一个二分类的小任务中,即num_classes=2

 其他的网络预训练权重下载链接可以查看torchvision.models下对应模型的python文件或在官网进行查找。

  PyTorch官方提供的两种加载权重方法说明如下:

二、方法1

 第一种方法是:根据模型结构,删除权重模型最后的分类层,替换成属于该分类任务对应的全连接层,如将ImageNet最后输出的1000个神经元(1000分类)替换成2个输出神经元(二分类)。具体代码如下:

# 使用官方给定的api进行网络结构和权重的加载
from torchvision.models.resnet import resnet18
net = resnet18(pretrained=True)   
# 冻结所有参数 使其不更新
for param in net.parameters():
    param.requires_grad = False
# 替换全连接层
in_channel = net.fc.in_features    # 获得原模型全连接层的输入特征大小
net.fc = nn.Linear(in_channel, num_classes)  # num_classes代表分类器的类别

这里先设置所有参数为不可训练,而新的nn.Linear是可训练的。

 如果不使用官方定义的模型结构,也可以使用自己定义好的,前提是自身定义的模型在网络结构的定义上跟官方是一致的(层的名称和网络的参数),自己定义的模型加载预训练权重的方式如下:

# 使用自定义的网络结构和权重的加载
net = resnet18(num_classes=1000)      # ***
# 载入预训练权重 model_weight_path:含权重名的模型路径 device:设备
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# 冻结所有参数 使其不更新
for param in net.parameters():
    param.requires_grad = False
# 替换全连接层
in_channel = net.fc.in_features    # 获得原模型全连接层的输入特征大小
net.fc = nn.Linear(in_channel, num_classes)  # num_classes代表分类器的类别

注意:这里是在原有的模型基础上进行增删网络结构的操作,如想一开始就定义针对该任务的方式请看方法2。

三、方法2

 第二种方式是:首先读取权重到变量中但不载入(读取的参数是以一个字典的形式存储),然后找到预训练权重中不需要的权重将它删除,最后再载入到网络中。

from torchvision.models.resnet import resnet18
net = resnet18(pretrained=True) 
# 将定义的网络中参数读取到net_weights字典变量中 (key: name, val: weights)
net_weights = net.state_dict()    
# 读取预训练的权重
weights = torch.load(model_weight_path, map_location=device)    
del_key = []

# 遍历预训练权重,如果存在key值包含“fc”的权重就将它从预训练权重字典中删除
for key, _ in weights.items():
    if "fc" in key:
        del_key.append(key)
for key in del_key:
    del weights[key]

# 载入权重
net.load_state_dict(weights, strict=False)

注意这里要设置strict=False。默认为True会严格按key值载入权重,如果出现缺失的权重会报错,这里因为我们已经删除了一部分预训练权重所以要设置为False。但需要注意的是使用strict=False时,即使模型和权重的关系不对应,载入也不会报错,如将ResNet34上的预训练权重以此方式载入到ResNet18模型时,会根据网络结构对可加载的权重进行加载,如果这样设置有可能会造成预训练权重没有加载到要求的网络中,造成无效地预训练/迁移。

 使用自定义结构进行权重加载的方式如下:

net = resnet18(num_classes=2)    # ***
# 将定义的网络中参数读取到net_weights字典变量中 (key: name, val: weights)
net_weights = net.state_dict()    
# 读取预训练的权重
weights = torch.load(model_weight_path, map_location=device)    
del_key = []

# 遍历预训练权重,如果存在key值包含“fc”的权重就将它从预训练权重字典中删除
for key, _ in weights.items():
    if "fc" in key:
        del_key.append(key)
for key in del_key:
    del weights[key]

# 载入权重
net.load_state_dict(weights, strict=False)

 还有比较麻烦的情况就是定义的网络中某一层的key值和预训练权重对应层的key值不一样,这个比较麻烦,一般需要手动修改字典中的key值,本文不再详述。

四、参考链接

[1] https://www.csdn.net/tags/MtTaEgzsMTk3NjAxLWJsb2cO0O0O.html

[2] https://pytorch.org/vision/stable/models.html

你可能感兴趣的:(笔记,pytorch,深度学习,python)