本文介绍如何在pytorch中载入模型的部分权重
, 总结了2个比较常见的问题:
比如在花卉数据集分类时只有
5类
,所以最后一层全连接层节点个数为5,但是我们载入的预训练权重是针对ImageNet-1k的权重,它的全连接层节点个数是1000
,很明显是不能直接
载入预训练模型权重的。
能不能载入部分权重呢?
当然这要看你对网络是如何修改的,如果你是在网络的高层进行结构的修改的话,那么相对底层的没有被修改过的权重还是可以载入的,因为底层都是比如Backbone都是比较通用
的权重,载入之后对我们的训练是很有帮助的。
以分类网络ResNet为例说明,对应项目中的load_weights.py
来介绍对部分权重进行载入。
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
# option2
# net = resnet34(num_classes=5)
# pre_weights = torch.load(model_weight_path, map_location=device)
# del_key = []
# for key, _ in pre_weights.items():
# if "fc" in key:
# del_key.append(key)
#
# for key in del_key:
# del pre_weights[key]
#
# missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
# print("[missing_keys]:", *missing_keys, sep="\n")
# print("[unexpected_keys]:", *unexpected_keys, sep="\n")
if __name__ == '__main__':
main()
下载官方提供的ResNet34预训练模型, 并将它命名为resnet34-pre.pth
,接下来介绍官方提供的载入部分权重的方法。
num_classes
参数,此时默认的num_classes=1000,此时就可以直接载入官方的预训练权重。因为我们使用的是默认的全连接层个数1000,与预训练权重是一致的。# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
5
,接下来该怎么办呢?首先查看resnet34模型搭建的源码。可以看到全连接层是通过sef.fc=nn.Linear(512*block.expansion,num_class)
这条语句实现的。nn.Linear
类,可以看到它有这么几个参数self.in_features
和self.out_features
,分别表示全连接层的输入和输出
的节点个数。对于imagenet-1k,输出节点个数self.out_features对应的就是1000. 因此我们可以通过fc.in_features
获得网络的输入节点个数,然后输出节点个数定义为我们自己的分类个数5
。net.fc=nn.Linear(in_channel,5)
通过创建新的全连接层来替换原来的全连接层。这样我们就变相的载入了Conv1
到layer4_x
的层结构,替换掉全连接层相当于没有载入全连接层权重,刚好符合我们的要求
net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
num_classes
参数,也就是最后一个全连接层节点个数一开始就设置为5了。此时就不能像前一种方法一样直接通过net.load_state_dict(torch.load(model_weight_path, map_location=device))
载入预训练权重了。因为网络的全连接层节点个数和预训练模型是不一样的,直接载入就会报错
。我们应该怎么办呢?(torch.load(model_weight_path, map_location=device)
,先读取预训练权重保存为一个有序字典Orderedict
的形式。每个键值对对应一组参数和权重。resnet34
查看构建的代码,可以看到,其全连接层为self.fc
包含了fc
字段。除此之外,也可以通过实例化后的模型,调用state_dict()
函数,查看模型的所有模型权重的key和value值:net = resnet34(num_classes=5)
net_weights = net.state_dict()
fc.weight
和fc.bias
,此时我们可以遍历pre_weights的每个key值,如果key中包含有fc
这个字段我们就可以知道它是属于全连接层的权重,后续把包含fc的权重删除掉,然后我们再去载入剩下的权重。实例化的模型和载入的模型
,他们权重的名称(key值)要是一样的才可以载入和方便删减。还有一种情况可能载入模型的key与实例化的模型中的key
值不一样。那么这种情况的话就会比较麻烦点。那么就需要将载入模型的key值跟实例化一一对应,将载入模型的key改为实例化模型的key值
。这就需要你对网络搭建过程非常清楚,你要知道每个层它所对应的权重是什么,这样的话就可以编辑有序字典中的key来载入你想载入的权重。这个例子我们载入的权重和我们创建的模型它的key值都是一样的,因此相对于刚才说的这种情况,载入会比较简单些。fc
字段,我们就将这个key值先存到del_key
列表中。通过调试可以发现del_key
存的就是fc_weights
和fc_bias
。紧接着我们再遍历del_key依次将这些key从pre_weights
字典中删除。 pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
strict=False
, 如果你不传的话,它默认是为True的。如果strict=True
它会严格的载入每个key值,因为我们删减掉全连接中的权重,因此就不能将strict设置为True。net.load_state_dict(pre_weights, strict=False)
会返回两个 变量,分别是missing_keys
和unexpected_keys
。
missing_key
:表示在我们实例化的模型net中有部分权重并没有在pre_weights
预训练权重中出现,就相当于与pre_weights
中漏掉了这些权重。unexpected_key
:就是说在我们载入的pre_weights
中有一部分权重它不在我们的net中,此时就会存在unexpected_keys
中。针对我们刚才讲的情况,应该会出现两个missing_key
:fc.weights和fc.bias: missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
执行以后打印的信息:>> [missing_keys]:
>> fc.weight
>> fc.bias
>> [unexpected_keys]:
可以看到missing_key
中有fc.weights和fc.bias,在unexpected_keys
中是没有任何参数的。也就时除了fc.weights和fc.bias两个全连接参数外,其他参数都载入进来了。如果有些人,除了
fc层外还改动了某些高层的结构如resnet中Conv5_x,我们如何去载入低层没有改动的权重呢?
: 此时对于resnet模型就需要载入除了Conv5_x
和fc
层之外的所有权重
此时我们可以在条件中,判断key是否包含layer4
,如果有的话也将它删掉。
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key or "layer4" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
执行之后,我们发现在missing_key
列表中除了我们之前两个全连接层权重之外,剩下,剩下的都是layer4所对应的权重,也就是说我们也没有将layer4所对应的权重载入进去。
以上介绍的是2种比较常见的载入部分权重的方法,除了我们讲到的在载入的权重的有序字典筛选之外,我们可以自己新创建一个字典,新创建一个字典之后,可以自己组建key,value然后用上文介绍的方法进行载入就可以了,这样的话会更加的灵活.
代码链接:https://pan.baidu.com/s/1j34QBVb9ZKxWX7d1Vm9QrQ?pwd=stxx
提取码:stxx