pytorch加载部分训练模型

加载部分预训练模型

    checkpoint_finetune = torch.load('checkpoint1/model_best.pth.tar')
    model_dict = model.state_dict()
    pretrained_dict = checkpoint_finetune['state_dict']
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict) 
    model.load_state_dict(model_dict)

load_state_dict中的strict关键字

strict关键字只是说不匹配的关键字都不加载

你可能感兴趣的:(pytorch加载部分训练模型)