[记录]PyTorch加载预训练权重时可能存在的问题

本文主要记录下我在复现CNN经典模型中加载官方预训练权重需要注意的点以及常见的错误。

1、常规加载预训练权重

        本节所涉及的方法必须保证模型中每一层的名字与预训练权重的对应层名字相同

        法1:

        weights=torch.load(opt.weight,map_location=device)
        net.load_state_dict(weights)

        opt.weight是预训练权重的路径;

        法2:

net=model.to(device)
ckpt=torch.load(weight_path, map_location=device)
ckpt={k: v for k, v in ckpt.items() if net.state_dict()[k].numel() == v.numel()}

        (个人感觉)如果没有对模型做任何修改,法一更加简洁;

2、复现代码后加载预训练权重可能存在的问题

        本节的前提是不修改任意一层的结构(即网络整体的结构是不变的)。相信不少小伙伴和我一样在自己复现完代码之后,总会在加载预训练权重报错。出现这种情况无外乎两种原因:搭建的网络的层名字与预训练权重名字不对应;搭建的网络中出现了多余的层(例如,官方实现时卷积层没有偏置,而你实现过程中加了偏置;或者是搭建的网络中出现了冗余即空的列表)。

        针对名字不对应的情况,可以通过以下代码完成权重加载:

        

from collections import OrderedDict
ckpt=torch.load(weight_path, map_location=device)
net=model.to(device)    #加载模型
module_lst=[i for i in net.state_dict()]    #搭建网络的所有层的名字
weights=OrderedDict()    #创建一个有序字典,存放权重
for idx,(k,v) in enumerate(ckpt.items()):
   if net.state_dict()[module_lst[idx]].numel()==v.numel():    
        weights[module_lst[idx]]=v    #如果对应层参数量相同,保存当前层以及对应参数
net.load_state_dict(weights, strict=False)

        如果上述方式仍然报错的话,就需要检查我们搭建的网络的层数是否与预训练权重的相同(我在检查时候通常会保存到两个txt文件中方便对比;如果整体结构没修改的话,问题通常会出现在一些小的点,例如多加了偏置或者出现了冗余的层),下面就我在复现EfficientNet时加载预训练权重出现的问题做简要说明。

        下图是预训练权重对应层以及参数以及我搭建的模型对应的参数:

[记录]PyTorch加载预训练权重时可能存在的问题_第1张图片

[记录]PyTorch加载预训练权重时可能存在的问题_第2张图片

        通过对比两个字典,发现元素个数不相同=>我们的模型参数数量要多于预训练权重。仔细一看发现我们模型的每个卷积层都多了一行(我也不太清楚为什么会多一行,并且里面没有任何参数):

         通过下面代码将模型中所有的num_batches_tracked删除掉:

ckpt=torch.load(opt.weight,map_location=device)
weights = collections.OrderedDict()
net_layer=net.state_dict()
new_model_dict = {}
for layer_name,val in net_layer.items():
   if "num_batches_tracked" in layer_name:
        pass
   else:
        new_model_dict[layer_name] = val
module_lst = [i for i in new_model_dict]
for idx, (k, v) in enumerate(ckpt.items()):
    if net.state_dict()[module_lst[idx]].numel() == v.numel():
        weights[module_lst[idx]] = v
net.load_state_dict(weights,strict=False)

        经过处理之后,发现参数量与预训练权重的相同。如果处理之后仍然不相同那就要去细心比对了(通常是由于偏置引起的)   

[记录]PyTorch加载预训练权重时可能存在的问题_第3张图片

        关于更改网络结构后预训练权重如何加载以及能否使用预训练权重,由于目前未涉及到该方面,后续需要的话会做记录。    

 

你可能感兴趣的:(pytorch,python,深度学习)