Pytorch只加载部分参数权重 load (pth文件) & 加载模型不完全匹配&module.后缀问题

当使用Pytorch作为学习框架时:

我们已经建立好自己的model,但是想加载他人的网络模型架构时,总会出现一些问题,其中最为常见的三种问题

1. 网络每层的名称不相同,导致这个问题大体有两种原因

1.1 网络本身命名不同

1.2 有些网络在多块GPU上进行训练会使得每一层网络名称多一个module.的后缀,而有的只在CPU上训练则没有module.这个后缀。

针对上面两种1.1 与1.2 我们需要弄清下面这个函数的原理

model.load_state_dict()

该函数是指当我们已经构造好了一个模型后,可能要加载一些训练好的模型参数。举例子如下:

假设  trained.pth 是一个训练好的网络的模型参数存储载体。model = Net()是我们刚刚生成的一个新模型,我们希望model将trained.pth中的参数加载加载进来,这时我们就需要采用上述函数作为加载函数。

而这个函数加载的数据类型是OrderedDict,所谓OrderedDict,就是一个有着固定顺序的字典,当我们想要自己创建一个OrderedDict时需要按照下述引用OrderedDict。

from collections import OrderedDict

以1.2为例,如果我们想要将待加载模型trained.pth中的层的名字中的module.这个后缀去掉,则需采用以下代码。

new_state_dict = OrderedDict()
for k in 待加载的网络参数的位置(也是OrderedDict类型):
    name = k.replace('module.', '')
    new_state_dict[name] = checkpoint['model_state_dict'].setdefault(k)
self.model.load_state_dict(new_state_dict)

2. 加载模型不完全匹配

当加载模型不完全匹配时,我们可以采用

model.load_state_dict(new_state_dict,strict=False)

其中的strict参数设置为false则使得网络不完全加载目标网络参数,为True则完全粘贴过来,但凡有一点不一致都会报错。

你可能感兴趣的:(pytorch,深度学习)