pytorch加载模型错误 RuntimeError: Error(s) in loading state_dict for Model: Missing key(s) in state_dict

模型在保存时侯以键对值保存,同时在加载时根据现在网络的键值查找模型对应的键值,然后加载。一般报错是因为模型和网络的键值不匹配。

1、最常见的问题是键值多了或者少了 module.

此种情况是模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.

1)可以通过:

model = nn.DataParallel(model)

将模型的键值加上module.

2) 也可以通过遍历模型的键对值修改键值。

   如:加载模型时删除多余的module.  代码如下

state_dict = torch.load(load_path)
for key, param in state_dict.items():
    if key.startswith('module.'):        #键值包含‘module.’ 则删除 
        state_dict[key[7:]] = param          
        state_dict.pop(key)
net.load_state_dict(state_dict)
        

2、详解load_state_dict(state_dict, False)的False参数

很多教程说名字不匹配直接添加False参数即可,但是这里需要注意一个大坑。

如果模型的键值和网络的键值完全不匹配,那么模型就没有加载预训练参数,虽然不再报错。

该False参数作用在于 非严格匹配加载模型,可以下面几种情况进行分析

1)模型包含网络的部分参数

比如说模型是resnet101模型,你现在的网络是resnet50。再假设resnet50的参数名包含在resnet101的参数中,那么直接使用False会为你的网络resnet50加载键值相同的参数。这样就避免了对resnet101的每个键对值进行循环匹配,看是否是resnet50需要的。

2)模型完全不包含网络的参数

情况如1,模型有100个参数,都包含'module.' ,网络也有100个参数,都没有'module.' 。这种情况下如果参数设置为False,会发现没有任何键值能匹配上,因此网络就不会加载任何参数。

3)再介绍一个False使用场景

比如蒸馏网络PISR中,教师网络包含Encoder和Decoder两部分,学生网络由其中的Decoder部分组成,所以在训练学生网络时,如果要加载教师网络保存的预训练模型,设置False会自动识别Decoder部分键值相同,然后加载。

综上,设置False参数后依旧是按照键值查询加载参数的,有多少键值匹配,就加载多少模型的参数。

 

3、只要参数尺寸相同,就能加载

比如说我有一个10层网络的模型,还有一个3层的网络。我想把其中第9层的参数加载到现在网络的1层。如果参数的尺寸相同,就可以遍历键对值。将参数加载到想要的键值中。

state_dict = torch.load(load_path)
new_state_dict = []
for key, param in state_dict.items():
    if 'conv9' in key:        # 如果找到conv9对应的参数,将其键值替换为网络的键
        new_state_dict[key.replace('conv9', 'conv1')] = param   
net.load_state_dict(new_state_dict)

你可能感兴趣的:(torch,python,深度学习,pytorch,加载模型,torch.load,state_dict)