pytorch加载预训练模型

首先是键name一致的情况、比较简单:

	#这是我们自己网络模型参数的有序字典形式(网络参数名:值)
 	net_dict = net.state_dict()
    #这是实际加载的预训练好的网络模型参数的有序字典形式    
    pretrained_dict = torch.load(pretrained_path)
    #从预训练的参数中加载我们的网络中需要的模型参数(这个很重要、有时需要冻结某一层的参数、可用这条语句从预训练的整个网络参数中筛选出我们需要的某一层的参数)
    pretrained_dict = {
     k: v for k, v in pretrained_dict.items() if k in net_dict}
    #字典的updata方法,进行字典的更新(个人感觉不是必要的)
    net_dict.update(pretrained_dict)
    #按照键与键的对应关系、加载网络参数    
    net.load_state_dict(net_dict)

然后是键name不一致的情况、也不难:
第一种情况:
当只想加载某一层的参数时,却发现预训练的模型参数名字 与 某一层的参数名字 仅仅差了 一个前缀的关系,那么可用切片的方式,返回一个新的有序字典,然后根据新的有序字典再加载参数。
先来看一下net网络trans这一层的参数:

model_dict = net.trans.state_dict()
print(model_dict.keys())

结果:
odict_keys([‘conv0.h2h_0.weight’, ‘conv0.h2h_0.bias’, ‘conv0.l2l_0.weight’, ‘conv0.l2l_0.bias’, ‘conv0.bnh_0.weight’, ‘conv0.bnh_0.bias’, ‘conv0.bnh_0.running_mean’, ‘conv0.bnh_0.running_var’, …
再来看一下整个预训练模型的参数:

 file = "C:\\Users\\Hou bin\\Desktop\\MINet_Res50.pth"
 pretrained_dict = torch.load(file, map_location='cpu')
 print(pretrained_dict.keys())

‘trans.conv0.h2h_0.weight’, ‘trans.conv0.h2h_0.bias’, ‘trans.conv0.l2l_0.weight’, ‘trans.conv0.l2l_0.bias’, ‘trans.conv0.bnh_0.weight’, ‘trans.conv0.bnh_0.bias’, ‘trans.conv0.bnh_0.running_mean’,…
有啥区别呢?
就是差了一个前缀!
如何加载?
代码如下:

 model_dict = net.trans.state_dict()#以有序字典形式返回trans这一层的全部参数
 file = "C:\\Users\\Hou bin\\Desktop\\MINet_Res50.pth"
 pretrained_dict = torch.load(file, map_location='cpu')
 pretrained_dict = {
     k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}  #把预训练网络参数trans这一层的参数名字前的前缀通过切片去掉k[6:]: v,返回新的有序字典、这样加载的时候参数名才能对应啊
 #更新现有的model_dict
 model_dict.update(pretrained_dict)#这一步也可以不需要、下一行代码直接加载pretrained_dict
 net.trans.load_state_dict(model_dict)

第二种情况:
有时候你的模型保存时含有 nn.DataParallel时,就会发现所有的dict都会有 module的前缀。
这时候加载含有module前缀的模型时,可能会出错。其实你只要移除这些前缀即可
pytorch加载预训练模型_第1张图片

方法同上!

你可能感兴趣的:(深度学习图像分类,python,深度学习,神经网络)