pytorch加载预训练 加载部分参数

最简单的:

   state_dict = torch.load(weight_path)
   self.load_state_dict(state_dict,strict=False)

 

这也是一种方法:

  checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)

  if 'epoch' not in checkpoint:
    print("can not find epoch")
    return model
  print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
  state_dict_ = checkpoint['state_dict']
  state_dict = {}
  
  # convert data_parallal to model
  for k in state_dict_:
    if k.startswith('module') and not k.startswith('module_list'):
      state_dict[k[7:]] = state_dict_[k]
    else:
      state_dict[k] = state_dict_[k]
  model_state_dict = model.state_dict()

  # check loaded parameters and created model parameters
  for k in state_dict:
    if k in model_state_dict:
      if state_dict[k].shape != model_state_dict[k].shape:
        print('Skip loading parameter {}, required shape{}, '\
              'loaded shape{}.'.format(
          k, model_state_dict[k].shape, state_dict[k].shape))
        state_dict[k] = model_state_dict[k]
    else:
      print('Drop parameter {}.'.format(k))
  for k in model_state_dict:
    if not (k in state_dict):
      print('No param {}.'.format(k))
      state_dict[k] = model_state_dict[k]
  model.load_state_dict(state_dict, strict=False)

参数尺度对不上可以用这个:

     model_dict = self.state_dict()
     pretrained_dict = torch.load(weight_path)

     pretrained_dict2 = {}

     for k, v in pretrained_dict.items():
            if k in model_dict:
                if "features.12" not in k and "features.13" not in k:
                    pretrained_dict2[k] = v

      model_dict.update(pretrained_dict2)
      self.load_state_dict(model_dict)
      print('Total params: %.2fM' % (sum(p.numel() for p in self.parameters()) / 1000000.0))

 

加载部分参数:

在预训练网络的基础上,修改部分层得到自己的网络,通常我们需要解决的问题包括: 
1. 从预训练的模型加载参数 
2. 对新网络两部分设置不同的学习率,主要训练自己添加的层 
一. 加载参数的方法: 
加载参数可以参考apaszke推荐的做法,即删除与当前model不匹配的key。代码片段为:

model = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

二. 不同层设置不同学习率的方法 
此部分主要参考PyTorch教程的Autograd machnics部分 
2.1 在PyTorch中,每个Variable数据含有两个flag(requires_grad和volatile)用于指示是否计算此Variable的梯度。设置requires_grad = False,或者设置volatile=True,即可指示不计算此Variable的梯度:

for param in model.parameters():
    param.requires_grad = False
1
2
注意,在模型测试时,对input_data设置volatile=True,可以节省测试时的显存 
2.2 PyTorch的Module.modules()和Module.children() 
参考PyTorch document和discuss 
在PyTorch中,所有的neural network module都是class torch.nn.Module的子类,在Modules中可以包含其它的Modules,以一种树状结构进行嵌套。当需要返回神经网络中的各个模块时,Module.modules()方法返回网络中所有模块的一个iterator,而Module.children()方法返回所有直接子模块的一个iterator。具体而言:

list(nn.Sequential(nn.Linear(10, 20), nn.ReLU()).modules())
Out[9]:
[Sequential (
   (0): Linear (10 -> 20)
   (1): ReLU ()
 ), Linear (10 -> 20), ReLU ()]

In [10]: list(nn.Sequential(nn.Linear(10, 20), nn.ReLU()).children())
Out[10]: [Linear (10 -> 20), ReLU ()]

2.3 选择特定的层进行finetune 
先使用Module.children()方法查看网络的直接子模块,将不需要调整的模块中的参数设置为param.requires_grad = False,同时用一个list收集需要调整的模块中的参数。具体代码为:

count = 0
    para_optim = []
    for k in model.children():
        count += 1
        # 6 should be changed properly
        if count > 6:
            for param in k.parameters():
                para_optim.append(param)
        else:
            for param in k.parameters():
                param.requires_grad = False
optimizer = optim.RMSprop(para_optim, lr)


原文:https://blog.csdn.net/u012494820/article/details/79068625 

你可能感兴趣的:(torch)