Extracting BERT Word Vectors in Batches

The code is adapted from an online source, with minor modifications and a few added comments. Initially, model_path is set to 'bert-base-uncased', so the model weights and vocabulary are downloaded into the cache (cache/torch/transformer); once saved with save_model, model_path can be pointed at the saved directory instead. Because the annotation scheme does not use wordpiece units, each word is represented by the vector of the first subword produced by tokenization. The batch is sorted from longest to shortest sentence to make it convenient for a downstream RNN.
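
As a small sketch of that workflow (using the Bertvec class defined below; the local directory name is just an example):

# First run: download 'bert-base-uncased' from the hub, then keep a local copy.
bert_embedding = Bertvec('bert-base-uncased', 'cpu')
bert_embedding.save_model('./bert-base-uncased/')
# Later runs: load from the local copy instead of downloading again.
bert_embedding = Bertvec('./bert-base-uncased/', 'cpu')
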
This code only extracts BERT word vectors; it cannot fine-tune the model. To fine-tune, first inherit from nn.Module, call extract_features inside forward, and remove the with torch.no_grad() block. The fix_embeddings flag here only freezes the wordpiece, position, and segment embeddings, i.e. the weights applied before the input enters the transformer layers. A minimal fine-tuning wrapper is sketched below.
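
A minimal sketch of that setup, assuming with torch.no_grad() has been removed from extract_features (the class name BertTagger and the linear tagging head are hypothetical, not part of the original code):

import torch.nn as nn

class BertTagger(nn.Module):
    def __init__(self, model_path, device, num_labels):
        super().__init__()
        self.bert_vec = Bertvec(model_path, device, fix_embeddings=True)
        # Register the underlying BertModel so its parameters show up in
        # BertTagger.parameters() (Bertvec itself is not an nn.Module).
        self.bert = self.bert_vec.model
        self.classifier = nn.Linear(768, num_labels)  # 768 = bert-base hidden size

    def forward(self, input_batch_list):
        # Gradients flow into BERT only if `with torch.no_grad()` was removed.
        word_vectors = self.bert_vec.extract_features(input_batch_list)
        return self.classifier(word_vectors)
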

import torch
from transformers import BertModel, BertTokenizer

class Bertvec:
    def __init__(self, model_path, device, fix_embeddings=True):
        self.device = device
        self.model_class = BertModel
        self.tokenizer_class = BertTokenizer
        self.pretrained_weights = model_path
        self.tokenizer = self.tokenizer_class.from_pretrained(self.pretrained_weights)
        self.model = self.model_class.from_pretrained(self.pretrained_weights).to(self.device)
        if fix_embeddings:
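            # freeze the wordpiece, position, and segment embedding weights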
            for name, param in self.model.named_parameters():
                if name.startswith('embeddings'):
                    param.requires_grad = False

    def extract_features(self, input_batch_list):
        batch_size = len(input_batch_list)
        words = [sent for sent in input_batch_list]
        word_seq_lengths = torch.LongTensor(list(map(len, words)))
        # length of each sentence; take the maximum word-level length
        max_word_seq_len = word_seq_lengths.max().item()
        word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True)
        # sort lengths in descending order and keep the permutation from the original order
        # to the sorted order, e.g. lengths [2, 3, 1] give the permutation [1, 0, 2]
        batch_tokens = []
        batch_token_ids = []
        subword_word_indicator = torch.zeros((batch_size, max_word_seq_len), dtype=torch.int64)
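        # subword_word_indicator[i, j] holds the index of the first subword of word j in sentence i;
        # unused (padding) positions stay 0, which later selects the [CLS] vector.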
        for idx in range(batch_size):
            one_sent_token = []
            one_subword_word_indicator = []
            for word in input_batch_list[idx]:
                word_tokens = self.tokenizer.tokenize(word)
                # tokenize the word into wordpieces
                one_subword_word_indicator.append(len(one_sent_token) + 1)
                # tokenization changes the sequence length, so record, for each original word,
                # the position of its first subword; the +1 accounts for the [CLS] token added below
                one_sent_token += word_tokens
                # accumulate the wordpiece tokens for this sentence
            # add [CLS] and [SEP] tokens
            one_sent_token = ['[CLS]'] + one_sent_token + ['[SEP]']
            one_sent_token_id = self.tokenizer.convert_tokens_to_ids(one_sent_token)
            # convert tokens to ids
            batch_tokens.append(one_sent_token)
            batch_token_ids.append(one_sent_token_id)
            subword_word_indicator[idx, :len(one_subword_word_indicator)] = torch.LongTensor(one_subword_word_indicator)
        token_seq_lengths = torch.LongTensor(list(map(len, batch_tokens)))
        max_token_seq_len = token_seq_lengths.max().item()
        # maximum sentence length after wordpiece tokenization
        batch_token_ids_padded = []
        for the_ids in batch_token_ids:
            batch_token_ids_padded.append(the_ids + [0] * (max_token_seq_len - len(the_ids)))
            # pad each sentence to the maximum token length
        batch_token_ids_padded_tensor = torch.tensor(batch_token_ids_padded)[word_perm_idx].to(self.device)
        subword_word_indicator = subword_word_indicator[word_perm_idx].to(self.device)
        # reorder both tensors with the permutation computed earlier, so the
        # (pre-tokenization) sentences run from longest to shortest
        with torch.no_grad():
            last_hidden_states = self.model(batch_token_ids_padded_tensor)[0]
        # last-layer hidden states from BERT
        batch_word_mask_tensor_list = []
        for idx in range(batch_size):
            one_sentence_vector = torch.index_select(last_hidden_states[idx], 0, subword_word_indicator[idx]).unsqueeze(0)
            # use the first subword's vector to represent each whole word, and add a batch dimension
            batch_word_mask_tensor_list.append(one_sentence_vector)
        batch_word_mask_tensor = torch.cat(batch_word_mask_tensor_list, 0)
        return batch_word_mask_tensor
        
    def save_model(self, path):
        # save the downloaded model and tokenizer files to path
        self.tokenizer.save_pretrained(path)
        self.model.save_pretrained(path)

if __name__ == '__main__':
    input_test_list = [["he", "comes", "from", "--", "encode"],
                       ["One", "way", "of", "measuring", "the", "complexity"],
                       ["I", "encode", "money"]
                       ]
    bert_embedding = Bertvec('./bert-base-uncased/', 'cpu', True)
    batch_features = bert_embedding.extract_features(input_test_list)
    print(batch_features)
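
For the three test sentences above, batch_features has shape (3, 6, 768): three sentences padded to the longest word count (six), each word represented by a 768-dimensional bert-base vector. Note that the rows follow the length-sorted order (longest sentence first), not the input order.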
