关于Pytorch加载模型参数的避坑指南

一、load_state_dict(strict)中参数 strict的使用

load_state_dict(strict)中的参数strict默认是True,这时候就需要严格按照模型中参数的Key值来加载参数,如果增删了模型的结构层,或者改变了原始层中的参数,加载就会报错。

         相反地,如果设置strict为Flase,就可以只加载具有相同名称的参数层,对于修改的模型结构层进行随机赋值。这里需要注意的是,如果只是改变了原来层的参数,但是没有换名称,依然还是会报错。因为根据key值找到对应的层之后,进行赋值,发现参数不匹配。这时候可以将原来的层换个名称,再加载就不会报错了。最后,大家需要注意的是,strict=Flase要谨慎使用,因为很有可能你会一点参数也没加载进来,具体原因请看下文。

二、使用多GPU训练后的模型加载问题

        多GPU训练模型的好处不必多说,毕竟“钞能力”的力量不可小觑。但是,我们需要注意的是,如何加载多GPU训练的模型参数。在执行完函数model = nn.DataParallel(model, device_ids=[0,1,2,3])这条语句后,会给网络中所有的结构层的名称添加module这个字符,此时,如果我们直接使用 model.load_state_dict(torch.load("model.pth"),strict=True)将会报错,如果你灵机一动将strict的参数改为False,程序是不会报错了,但是测试结果会低到离谱,因为压根就没有参数加载进来,每一层的名称前都添加了module,所以名称都是不匹配的。

         这时候有两种解决问题的方法,一是在加载模型前,依旧使用model = nn.DataParallel (model, device_ids=[0,1,2,3])给模型每一层名称前添加module的字符。不过当我们想要单卡去测试模型时就遇到问题了,此时我们需要手动删除掉模型名称中的"module."这7个字符,注意是7个,还有个 .    这样做可以自由地更改模型参数的名称,不仅可以删减前缀"module. ",同时也能增加前缀,这个在模型拼接时会比较方便。

import torch
import torch.nn as nn
import Model.pvt_v2 as PvT
from collections import OrderedDict

net=PvT.pvt_v2_b4()# 
state_dict = torch.load("/datasets/Dset_Jerry/Checkpoint/CC-CXRI-P/PvT_B32_S384/PvT_18.pkl")  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain module.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove module.
    new_state_dict[name] = v  # 新字典的key值对应的value为一一对应的值。
# load params
net.load_state_dict(new_state_dict, strict=True)  # 重新加载这个模型。

你可能感兴趣的:(机器学习,深度学习)