最简单的:
state_dict = torch.load(weight_path)
self.load_state_dict(state_dict,strict=False)
这也是一种方法:
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
if 'epoch' not in checkpoint:
print("can not find epoch")
return model
print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
state_dict_ = checkpoint['state_dict']
state_dict = {}
# convert data_parallal to model
for k in state_dict_:
if k.startswith('module') and not k.startswith('module_list'):
state_dict[k[7:]] = state_dict_[k]
else:
state_dict[k] = state_dict_[k]
model_state_dict = model.state_dict()
# check loaded parameters and created model parameters
for k in state_dict:
if k in model_state_dict:
if state_dict[k].shape != model_state_dict[k].shape:
print('Skip loading parameter {}, required shape{}, '\
'loaded shape{}.'.format(
k, model_state_dict[k].shape, state_dict[k].shape))
state_dict[k] = model_state_dict[k]
else:
print('Drop parameter {}.'.format(k))
for k in model_state_dict:
if not (k in state_dict):
print('No param {}.'.format(k))
state_dict[k] = model_state_dict[k]
model.load_state_dict(state_dict, strict=False)
参数尺度对不上可以用这个:
model_dict = self.state_dict()
pretrained_dict = torch.load(weight_path)
pretrained_dict2 = {}
for k, v in pretrained_dict.items():
if k in model_dict:
if "features.12" not in k and "features.13" not in k:
pretrained_dict2[k] = v
model_dict.update(pretrained_dict2)
self.load_state_dict(model_dict)
print('Total params: %.2fM' % (sum(p.numel() for p in self.parameters()) / 1000000.0))
加载部分参数:
在预训练网络的基础上,修改部分层得到自己的网络,通常我们需要解决的问题包括:
1. 从预训练的模型加载参数
2. 对新网络两部分设置不同的学习率,主要训练自己添加的层
一. 加载参数的方法:
加载参数可以参考apaszke推荐的做法,即删除与当前model不匹配的key。代码片段为:
model = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
二. 不同层设置不同学习率的方法
此部分主要参考PyTorch教程的Autograd machnics部分
2.1 在PyTorch中,每个Variable数据含有两个flag(requires_grad和volatile)用于指示是否计算此Variable的梯度。设置requires_grad = False,或者设置volatile=True,即可指示不计算此Variable的梯度:
for param in model.parameters():
param.requires_grad = False
1
2
注意,在模型测试时,对input_data设置volatile=True,可以节省测试时的显存
2.2 PyTorch的Module.modules()和Module.children()
参考PyTorch document和discuss
在PyTorch中,所有的neural network module都是class torch.nn.Module的子类,在Modules中可以包含其它的Modules,以一种树状结构进行嵌套。当需要返回神经网络中的各个模块时,Module.modules()方法返回网络中所有模块的一个iterator,而Module.children()方法返回所有直接子模块的一个iterator。具体而言:
list(nn.Sequential(nn.Linear(10, 20), nn.ReLU()).modules())
Out[9]:
[Sequential (
(0): Linear (10 -> 20)
(1): ReLU ()
), Linear (10 -> 20), ReLU ()]
In [10]: list(nn.Sequential(nn.Linear(10, 20), nn.ReLU()).children())
Out[10]: [Linear (10 -> 20), ReLU ()]
2.3 选择特定的层进行finetune
先使用Module.children()方法查看网络的直接子模块,将不需要调整的模块中的参数设置为param.requires_grad = False,同时用一个list收集需要调整的模块中的参数。具体代码为:
count = 0
para_optim = []
for k in model.children():
count += 1
# 6 should be changed properly
if count > 6:
for param in k.parameters():
para_optim.append(param)
else:
for param in k.parameters():
param.requires_grad = False
optimizer = optim.RMSprop(para_optim, lr)
原文:https://blog.csdn.net/u012494820/article/details/79068625