Loading only part of a pretrained model's parameters in PyTorch

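        # Assumes `import torch` and an already constructed `model`.
        # `checkpoint` starts out as the path (or file object) of the saved
        # file and is rebound to the loaded dictionary.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        # Restart best-score tracking instead of restoring the saved value.
        # best_score = checkpoint["best_score"]
        best_score = 0

        print("\t* Training will continue on existing model from epoch {}..."
              .format(start_epoch))

        # Parameters of the model as it is currently constructed.
        model_now_dict = model.state_dict()
        print(model_now_dict.keys())

        # Parameters stored in the checkpoint.
        load_pretrained_dict = checkpoint["model"]
        print(load_pretrained_dict.keys())

        # 1. Filter out unnecessary keys: here, skip the word embedding weights.
        new_state_dict = {k: v for k, v in load_pretrained_dict.items()
                          if k != "_word_embedding.weight"}
        # To load every parameter instead:
        # new_state_dict = dict(load_pretrained_dict)

        # 2. Overwrite entries in the existing state dict.
        model_now_dict.update(new_state_dict)
        # 3. Load the merged state dict into the model.
        model.load_state_dict(model_now_dict)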
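
The snippet above depends on a training script and a checkpoint file that are not shown here. As a rough, self-contained sketch of the same filter, update, and load pattern, the example below uses a hypothetical `TinyClassifier` model (not from the original post) and an in-memory "pretrained" model in place of a file on disk. The layer name `_word_embedding` matches the key skipped above. Everything else, including the vocabulary sizes, is an assumption made for illustration.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model; only the `_word_embedding` name comes from the post."""

    def __init__(self, vocab_size, embed_dim=4, num_classes=2):
        super().__init__()
        self._word_embedding = nn.Embedding(vocab_size, embed_dim)
        self._classifier = nn.Linear(embed_dim, num_classes)


torch.manual_seed(0)
pretrained = TinyClassifier(vocab_size=10)  # stands in for the saved checkpoint
model = TinyClassifier(vocab_size=20)       # new model with a larger vocabulary

# Keep a copy of the freshly initialised embedding to compare against later.
old_embedding = model._word_embedding.weight.detach().clone()

# 1. Filter out the key that should not be loaded.
skip_key = "_word_embedding.weight"
filtered = {k: v for k, v in pretrained.state_dict().items() if k != skip_key}

# 2. Overwrite the matching entries in the current state dict.
model_dict = model.state_dict()
model_dict.update(filtered)

# 3. Load the merged state dict back into the model.
model.load_state_dict(model_dict)

embedding_unchanged = torch.equal(model._word_embedding.weight, old_embedding)
classifier_copied = torch.equal(model._classifier.weight,
                                pretrained._classifier.weight)
print(embedding_unchanged, classifier_copied)
```

Because the merged dictionary still contains every key the model expects, `load_state_dict` can stay in its default strict mode. The skipped embedding simply keeps its own initial values.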

Getting this to work took clearing quite a few hurdles.
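
The post does not say what the hurdles were. One that commonly comes up with partial loading is a checkpoint tensor whose shape no longer matches the model, which `load_state_dict` rejects even when `strict=False` is passed. As a hedged sketch of one way around it (not necessarily what the author did), the example below keeps only entries whose name and shape both match, then loads them with `strict=False` and inspects the returned `missing_keys` and `unexpected_keys`. `TinyClassifier` and the vocabulary sizes are hypothetical.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model used only for this illustration."""

    def __init__(self, vocab_size, embed_dim=4, num_classes=2):
        super().__init__()
        self._word_embedding = nn.Embedding(vocab_size, embed_dim)
        self._classifier = nn.Linear(embed_dim, num_classes)


pretrained_dict = TinyClassifier(vocab_size=10).state_dict()
model = TinyClassifier(vocab_size=20)
model_dict = model.state_dict()

# Keep only parameters that exist in the new model with an identical shape.
compatible = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# strict=False tolerates keys that are absent from `compatible`; the return
# value reports which model keys were left untouched.
result = model.load_state_dict(compatible, strict=False)
missing = list(result.missing_keys)
unexpected = list(result.unexpected_keys)
print("not loaded:", missing)
print("unexpected:", unexpected)
```

Printing the missing keys after each load is a cheap check that only the intended layers were skipped.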
