【Pytorch】修改网络后加载预训练权重

内容

本文章带大家如何给自己修改过后的网络,加载预训练权重。

很多小伙伴针对某一模型进行修改的时候,在修改模型后想要加载预训练权重,会发现频频报错,其实最主要原因就是权重的shape对应不上。

注意:以下方法仅仅针对于在原网络改动不大的情况下加载预训练权重!

1、.pt文件----->model:从.pt文件直接加载预训练权重。

# 模板
ckpt = torch.load(weights)  # 加载预训练权重

model = Model() # 创建我们的模型
model_dict = model.state_dict() # 得到我们模型的参数

# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict and (v.shape == model_dict[k].shape)}

# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)

# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)

2、model 1 ------>model 2:获取一个模型的权重加载到另一个模型。

# 模板
import torchvision.models as models

# 创建model
# 类型 1 加载 经典模型 与 自己的模型
resnet50 = models.resnet50(pretrained=True)  # 创建预训练模型,并加载参数
model = Model()   # 创建我们的网络

# # 类型 2 加载 两个自己的模型
# ckpt = torch.load(weights)  # 加载预训练权重
# model_1 = Model_1()  # 创建预训练模型,并加载参数
# model_1.load_state_dict(ckpt, strict=False)
# model_2 = Model_2()   # 创建我们的网络
 
# 读取网络参数
pretrained_dict = resnet50().state_dict()  # 读取预训练模型参数
model_dict = model().state_dict()       # 读取我们的网络参数
 
# 将pretrained_dict里不属于net_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and (v.shape == model_dict[k].shape)}
 
# 更新修改之后的net_dict
model_dict.update(pretrained_dict)  # 将与 pretrained_dict 中 layer_name 相同的参数更新为 pretrained_dict 的
 
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict) 

其他

1、我是在yolov5-6.0网络的Backbone上修改的,改动并不大,仅仅是替换了、增添了某一模块,所以大部分的权重还是可以进行加载的,另外yolov5代码当中也有对加载预训练权重是否匹配的判断:
【Pytorch】修改网络后加载预训练权重_第1张图片
代码如下:
image-20220516185852798
其实原理一样,判断有没有相同的模块,有的话shape又是否相同,都满足才会放入加载的队列。

有些小伙伴在加载yolov5预训练权重的时候可能还遇到这种问题:明明什么都没有改动,然而预训练权重也会加载错误,这种情况就是使用的yolov5版本不同,其网络结构也不同,不同版本之间有相同模块,但是相同的模块两个版本中又各自有不同的参数格式(卷积核个数、卷积核大小等等),所以模块名称匹配得上,但是shape又不相同。

你可能感兴趣的:(Pytorch学习笔记,深度学习,YOLOv5,pytorch,深度学习,python)