加载训练好的BERT参数

将预训练模型中的bert部分取出来加载上去

base_model = BaseModel(config)
base_model_dict = base_model.state_dict()

加载训练好的模型

pre_state_dict = torch.load(Bert_path)

new_state_dict = {k: v for k, v in pre_state_dict.items() if k in base_model_dict}
base_model_dict.update(new_state_dict)
base_model.load_state_dict(base_model_dict)

class BaseModel(nn.Module):
    def __init__(self, config):
        super(BaseModel, self).__init__()
        self.bert_model = BertModel.from_pretrained(config.bert_uncased_path, output_hidden_states=True,
                                                    output_attentions=True)
        for p in self.bert_model.parameters():
            p.requires_grad = False  # 预训练模型加载进来后全部设置为不更新参数,然后再后面加层

    def forward(self, input_ids, attention_mask):
        last_hidden_state = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[1]
        return last_hidden_state

新的方法

  • 可以把训练好的模型参数加载给新的模型。新的模型和老模型有部分相同。
                    model_orig = Bert_Classify(class_num)
                    model_orig.load_state_dict(torch.load(classifier_name))

                    model = NewModel(one_t, class_num)
                    model.load_state_dict(model_orig.state_dict(), strict=False)

你可能感兴趣的:(加载训练好的BERT参数)