BERT Pytorch版本 源码解析(二)

BERT Pytorch版本 源码解析(二)

四、BertEmbedding 类解析

BertEmbedding部分是组成 BertModel 的第一部分,今天就来讲讲 BertEmbedding 的内部实现细节。

4.1、Embedding 的组成以及设置

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

上面的代码是 BertEmbedding 类的初始化函数,在这块很明显 BertEmbedding 并似乎并没有很特别的地方。总的是设置了三种类型的 embedding,分别是word_embedding,position_embedding,token_type_embedding三种组成。首先,这三种embedding都是用pytorch自带的nn.Embedding 随机生成的,而且它们的向量长度都是 config.hidden_size。之后是一个常见的LayerNorm 以及 Dropout层,这部分就不解释了。

4.2、具体实现

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

首先输入是input_ids或者token_type_ids,input_ids是一个[Batch_size, Seq_length]维度的向量,每一个元素表示对应词表中的index,token_type_ids是对于一个输入存在两个句子的情况,利用 0 和 1 来区分第几个句子的,所以这个部分其实对于大部分任务来说是可以省略的。

然后是关于position_ids的生成,它是自动生成的一个向量,torch.arange(seq_length)是自动生成一个从0开始到seq_length - 1的长度为seq_length的向量。

如果 token_type_ids 是None的情况下则自动生成一个全为0的向量,即所有的输入都是单句的输入。

之后就是利用nn.Embedding来生成三个[Batch_size, Seq_length, Hidden_size]的向量,然后将三个向量进行叠加操作之后进行LayerNorm以及Dropout操作,这就是BertEmbedding的工作原理。

 

你可能感兴趣的:(NLP)