Pytorch加载模型不完全匹配 & 只加载部分参数权重 load (pth文件) & 从网络加载权重(URL)

加载模型不完全匹配

model.load_state_dict(torch.load(weight_path), strict=False)

当权重中的key和网络中匹配就加载,不匹配就跳过

如果strict是True,那必须完全匹配,不然就报错

默认是True


但是注意,如果是像英文模型迁移到中文,改了class num的话,例如由26改为3600,这时模型不匹配用它是解决不了的,因为此时模型的key名字是对应的上的,只是权重的size不同 看

只加载部分参数权重

如果发生上述情况的话,那就需要把加载到的模型的中,不匹配的那几项删掉,然后加载其他项

x = torch.load(self.weight)
del x['char_recognizer.classifier.bias']
del x['char_recognizer.classifier.weight']
self.load_state_dict(x, strict=False)

或者

# Use when some parts of pretrained model are not needed
# pretrained_dict = checkpoint['state_dict']
# model_dict = model.state_dict()

# # 1. filter out unnecessary keys
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# # 2. overwrite entries in the existing state dict
# model_dict.update(pretrained_dict) 
# # 3. load the new state dict
# model.load_state_dict(model_dict)

或者

load pretrained model 然后通过args.finetune_ignore指定忽略的参数

code is from DAB-DETR

def clean_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict
    
#code is from DAB-DETR
if not args.resume and args.pretrain_model_path:
    checkpoint = torch.load(args.pretrain_model_path, map_location='cpu')['model']
    from collections import OrderedDict
    _ignorekeywordlist = args.finetune_ignore if args.finetune_ignore else []
    ignorelist = []

    def check_keep(keyname, ignorekeywordlist):
        for keyword in ignorekeywordlist:
            if keyword == keyname:
                ignorelist.append(keyname)
                return False
        return True

    # logger.info("Ignore keys: {}".format(json.dumps(ignorelist, indent=2)))
    _tmp_st = OrderedDict({k:v for k, v in clean_state_dict(checkpoint).items() if check_keep(k, _ignorekeywordlist)})
    _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
        

从网络加载权重

#code is from DETR
checkpoint = torch.hub.load_state_dict_from_url(
    url, map_location='cpu', check_hash=True)

你可能感兴趣的:(Pytorch)