PyTorch: loading model parameters for finetune training

When retraining (finetuning) a model in PyTorch, loading a checkpoint directly with torch.load() often raises errors.
Below is a loading routine that avoids this: if the checkpoint contains a tensor with the same name and size as one in the model, it is loaded; every other tensor in the model simply keeps its initialization.
The code is as follows:

import torch

def check_keys(model, pretrained_state_dict):
    # Compare the checkpoint keys with the model keys; fail if nothing overlaps.
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    assert len(used_pretrained_keys) > 0, 'load None from pretrained checkpoint'
    return True

def remove_prefix(state_dict, prefix):
    # Strip a prefix such as 'module.' (added by nn.DataParallel) from every key.
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}

def load_model(model, pretrained_path, load_to_cpu):
    # Map the checkpoint onto the CPU or the current GPU, normalize its keys,
    # then copy the matching tensors into the model; strict=False skips the rest.
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model

Code source
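
A minimal usage sketch for finetuning is shown below. The toy network Net, the checkpoint path 'pretrained.pth', and the optimizer settings are placeholders chosen for illustration, not part of the original code:

import torch.nn as nn

class Net(nn.Module):
    # Toy network for illustration; imagine the classifier head was replaced for the new task.
    def __init__(self, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.features(x).mean(dim=[2, 3])  # global average pooling
        return self.classifier(x)

model = Net(num_classes=5)
model = load_model(model, 'pretrained.pth', load_to_cpu=True)  # hypothetical checkpoint path
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
# ... continue with the usual training loop on the finetuning data ...

With strict=False, parameters that exist in the model but not in the checkpoint (here the replaced classifier head) simply keep their fresh initialization, which is exactly the behavior described above.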
