FLAT Code Walkthrough (2): The Model

Paper: FLAT: Chinese NER Using Flat-Lattice Transformer (ACL 2020)

We go straight to the model itself; the input pipeline was covered in detail in the previous post.

V0 version: without BERT

model = Lattice_Transformer_SeqLabel(embeddings['lattice'],
                 embeddings['bigram'],
                 args.hidden,  # 128
                 len(vocabs['label']),  # label_size
                 args.head, args.layer,  # 8, 1
                 args.use_abs_pos,  # False: whether to use absolute position encoding
                 args.use_rel_pos,  # True: whether to use relative position encoding
                 args.learn_pos,  # False: whether the absolute/relative position encodings are learnable (i.e. require gradients)
                 args.add_pos,  # False: whether to inject position information into the transformer layer via concat
                 args.pre, args.post,  # '', 'an'
                 args.ff,  # 128*3: number of hidden units in the feed-forward intermediate layer
                 args.scaled, dropout,  # False, dropout
                 args.use_bigram,  # 1
                 mode, device,
                 vocabs,
                 max_seq_len=max_seq_len,
                 rel_pos_shared=args.rel_pos_shared,  # True
                 k_proj=args.k_proj,  # False
                 q_proj=args.q_proj,  # True
                 v_proj=args.v_proj,  # True
                 r_proj=args.r_proj,  # True
                 self_supervised=args.self_supervised,  # False
                 attn_ff=args.attn_ff,  # False: whether to add a linear layer at the end of self-attention
                 pos_norm=args.pos_norm,  # False: whether to normalize the position encodings
                 ff_activate=args.ff_activate,  # relu
                 abs_pos_fusion_func=args.abs_pos_fusion_func,  # nonlinear_add
                 embed_dropout_pos=args.embed_dropout_pos,  # 0
                 four_pos_shared=args.four_pos_shared,  # True: for relative position encoding only; whether the 4 position encodings share weights
                 four_pos_fusion=args.four_pos_fusion,  # ff_two: fusion method for the 4 position encodings
                 four_pos_fusion_shared=args.four_pos_fusion_shared,  # True: whether the position embedding formed by fusing the 4 encodings is shared
                 use_pytorch_dropout=args.use_pytorch_dropout  # 0
                 )

Below we walk through some key code blocks of Lattice_Transformer_SeqLabel.

Overall structure

  1. Position encoding
        if self.use_rel_pos:
            pe = get_embedding(max_seq_len, hidden_size, rel_pos_init=self.rel_pos_init)  # [2*max_seq_len+1, hidden_size]
            pe_sum = pe.sum(dim=-1, keepdim=True)  # [2*max_seq_len+1, 1]
            if self.pos_norm:
                with torch.no_grad():
                    pe = pe/pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
            if self.four_pos_shared:
                self.pe_ss = self.pe
                self.pe_se = self.pe
                self.pe_es = self.pe
                self.pe_ee = self.pe
            else:
                self.pe_ss = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_se = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_es = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_ee = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)

Sinusoidal position encodings are used here:

def get_embedding(max_seq_len, embedding_dim, padding_idx=None, rel_pos_init=0):
    """Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    rel_pos_init:
    if 0, the relative-position embedding matrix covering -max_len..max_len is initialized with indices 0..2*max_len;
    if 1, it is initialized directly with positions -max_len..max_len.
    """
    num_embeddings = 2*max_seq_len+1
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
    if rel_pos_init == 0:
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)  # [num_embeddings, half_dim]
    else:
        emb = torch.arange(-max_seq_len, max_seq_len+1, dtype=torch.float).unsqueeze(1)*emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)  # [num_embeddings, embedding_dim]
    if embedding_dim % 2 == 1:
        # zero pad
        emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
    if padding_idx is not None:
        emb[padding_idx, :] = 0
    return emb
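
A quick usage sketch of get_embedding (the sizes below are hypothetical; it assumes the function above and its imports are in scope):

import math
import torch

# Hypothetical sizes, just to show the output shape and the lookup convention.
max_seq_len, hidden_size = 10, 8
pe = get_embedding(max_seq_len, hidden_size)   # rel_pos_init=0 by default
print(pe.shape)                                # torch.Size([21, 8]) == [2*max_seq_len+1, hidden_size]
# A relative distance d in [-max_seq_len, max_seq_len] is later looked up at row
# d + max_seq_len (see Four_Pos_Fusion_Embedding below), so row 0 serves distance -10
# and row 2*max_seq_len serves distance +10.
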
  2. Whether to use bigrams (optional)
if self.use_bigram:
    self.bigram_size = self.bigram_embed.embedding.weight.size(1)
    self.char_input_size = self.lattice_embed.embedding.weight.size(1) + self.bigram_size
else:
    self.char_input_size = self.lattice_embed.embedding.weight.size(1)

self.lex_input_size = self.lattice_embed.embedding.weight.size(1)

If bigrams are used, the bigram embeddings are concatenated with the lattice embeddings; otherwise only the lattice embeddings are used.
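
A minimal shape sketch of this concatenation (all sizes here are made up; in the actual forward pass the bigram embeddings are additionally zero-padded over the appended word positions before the concat):

import torch

bs, max_seq_len, lattice_dim, bigram_dim = 2, 5, 50, 50   # hypothetical sizes
raw_embed = torch.randn(bs, max_seq_len, lattice_dim)     # unigram/lattice embeddings
bigrams_embed = torch.randn(bs, max_seq_len, bigram_dim)  # bigram embeddings
raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
print(raw_embed_char.shape)   # torch.Size([2, 5, 100]) -> char_input_size = 50 + 50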

  3. Transformer Encoder
self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)

self.encoder = Transformer_Encoder(self.hidden_size, self.num_heads, self.num_layers,
                                   relative_position=self.use_rel_pos, ...)

  4. The final layers of the network
self.output = nn.Linear(self.hidden_size, self.label_size)
self.crf = get_crf_zero_init(self.label_size)
self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
  5. The model's forward function
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e,
                target, chars_target=None):
        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        raw_embed = self.lattice_embed(lattice)  # lattice embedding
        
        if self.use_bigram:
            bigrams_embed = self.bigram_embed(bigrams)
            bigrams_embed = torch.cat([bigrams_embed,
                                       torch.zeros(size=[batch_size, max_seq_len_and_lex_num - max_seq_len,
                                                         self.bigram_size]).to(bigrams_embed)], dim=1)
            # [bs, max_seq_len_and_lex_num, lattice_embed_size+bigram_embed_size]
            raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)  
        else:
            raw_embed_char = raw_embed

        if self.embed_dropout_pos == '0':
            raw_embed_char = self.embed_dropout(raw_embed_char)
            raw_embed = self.gaz_dropout(raw_embed)

        # [bs, max_seq_len_and_lex_num, hidden_size]
        embed_char = self.char_proj(raw_embed_char)
        char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool()  # [bs, max_len]
        # Fills elements of self_tensor with 0 where char_mask is False
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)  # [bs, max_len, 1*hidden_size]

        # [bs, max_seq_len_and_lex_num, hidden_size]
        embed_lex = self.lex_proj(raw_embed)
        lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool())  # True for the trailing lexicon-word positions
        # Fills elements with 0 where lex_mask is False, i.e. zeroes out the leading char positions
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)  # [bs, max_len, 1*hidden_size]

        embedding = embed_char + embed_lex

        encoded = self.encoder(embedding, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e)

        encoded = encoded[:, :max_seq_len, :]  # only the char positions are used for prediction
        pred = self.output(encoded)

        mask = seq_len_to_mask(seq_len).bool()

        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
          
            # debugging snippet in the original code: prints the loss and exits at a specific batch
            if self.batch_num == 327:
                print('{} loss:{}'.format(self.batch_num, loss))
                exit()

            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            if self.self_supervised:
                chars_pred = self.output_self_supervised(encoded)
                result['chars_pred'] = chars_pred

            return result

  • embedding = embed_char + embed_lex is the embedding of the whole model
    • embed_char uses char_mask to zero out the trailing lexicon-word positions, keeping the embeddings of the leading characters
    • embed_lex uses lex_mask to zero out the leading character positions, keeping the embeddings of the trailing lexicon words (see the masking sketch below)
    • so embed_char and embed_lex are projected by self.char_proj and self.lex_proj respectively, learning separate projections for characters and lexicon words
  • encoded = self.encoder(...) covers the Self-Attention, Add & Norm, FFN, Add & Norm layers
  • pred = self.output(encoded) is the final Linear layer
  • loss = self.crf(pred, target, mask) is the final CRF layer

The above corresponds to the overall architecture shown in Figure 2 of the paper.
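
To make the char/lexicon masking concrete, here is a minimal self-contained sketch (the lengths are made up, and seq_len_to_mask is re-implemented inline instead of being imported from fastNLP):

import torch

seq_len = torch.tensor([3])   # 3 characters in the sentence
lex_num = torch.tensor([2])   # 2 matched lexicon words appended after the characters
max_len = 5                   # max_seq_len_and_lex_num for this toy batch

def to_mask(lengths, max_len):
    # minimal stand-in for fastNLP's seq_len_to_mask
    return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

char_mask = to_mask(seq_len, max_len)                        # [[ True,  True,  True, False, False]]
lex_mask = to_mask(seq_len + lex_num, max_len) ^ char_mask   # [[False, False, False,  True,  True]]
# embed_char keeps only the char positions, embed_lex keeps only the word positions,
# so embedding = embed_char + embed_lex simply puts each token's projection in its own slot.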

Key details

Next, let's analyze the main contents of self.encoder(...), i.e. Transformer_Encoder:

  1. Fusing the 4 kinds of position information into a relative position encoding, i.e. Equation (8) in the paper
if self.four_pos_fusion_shared:
    self.four_pos_fusion_embedding = \
        Four_Pos_Fusion_Embedding(self.pe, self.four_pos_fusion, self.pe_ss, self.pe_se, self.pe_es, self.pe_ee,
                                  self.max_seq_len, self.hidden_size, self.mode)
else:
    self.four_pos_fusion_embedding = None
class Four_Pos_Fusion_Embedding(nn.Module):
    def __init__(self, pe, four_pos_fusion, pe_ss, pe_se, pe_es, pe_ee,
                 max_seq_len, hidden_size, mode):
        super().__init__()
        self.mode = mode
        self.hidden_size = hidden_size
        self.max_seq_len = max_seq_len
        self.pe_ss = pe_ss
        self.pe_se = pe_se
        self.pe_es = pe_es
        self.pe_ee = pe_ee
        self.pe = pe  # [2*max_seq_len+1, hidden_size]
        self.four_pos_fusion = four_pos_fusion
        if self.four_pos_fusion == 'ff':
            self.pos_fusion_forward = nn.Sequential(nn.Linear(self.hidden_size*4, self.hidden_size),
                                                    nn.ReLU(inplace=True))
        if self.four_pos_fusion == 'ff_linear':
            self.pos_fusion_forward = nn.Linear(self.hidden_size*4, self.hidden_size)
          
        elif self.four_pos_fusion == 'ff_two':
            self.pos_fusion_forward = nn.Sequential(nn.Linear(self.hidden_size*2, self.hidden_size),
                                                    nn.ReLU(inplace=True))
        elif self.four_pos_fusion == 'attn':
            self.w_r = nn.Linear(self.hidden_size, self.hidden_size)
            self.pos_attn_score = nn.Sequential(nn.Linear(self.hidden_size*4, self.hidden_size*4),
                                                nn.ReLU(),
                                                nn.Linear(self.hidden_size*4, 4),
                                                nn.Softmax(dim=-1))
        elif self.four_pos_fusion == 'gate':
            self.w_r = nn.Linear(self.hidden_size, self.hidden_size)
            self.pos_gate_score = nn.Sequential(nn.Linear(self.hidden_size*4, self.hidden_size*2),
                                                nn.ReLU(),
                                                nn.Linear(self.hidden_size*2, 4*self.hidden_size))

    def forward(self, pos_s, pos_e):
        batch = pos_s.size(0)
        max_seq_len = pos_s.size(1)
        # [bs, max_seq_len, 1] - [bs, 1, max_seq_len] = [bs, max_seq_len, max_seq_len]
        pos_ss = pos_s.unsqueeze(-1) - pos_s.unsqueeze(-2)
        pos_se = pos_s.unsqueeze(-1) - pos_e.unsqueeze(-2)
        pos_es = pos_e.unsqueeze(-1) - pos_s.unsqueeze(-2)
        pos_ee = pos_e.unsqueeze(-1) - pos_e.unsqueeze(-2)
        
        # [bs, max_seq_len, max_seq_len, hidden_size]
        pe_ss = self.pe_ss[(pos_ss).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
        pe_se = self.pe_se[(pos_se).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
        pe_es = self.pe_es[(pos_es).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
        pe_ee = self.pe_ee[(pos_ee).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])

        if self.four_pos_fusion == 'ff':
            pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
            rel_pos_embedding = self.pos_fusion_forward(pe_4)
        if self.four_pos_fusion == 'ff_linear':
            pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
            rel_pos_embedding = self.pos_fusion_forward(pe_4)
        if self.four_pos_fusion == 'ff_two':
            pe_2 = torch.cat([pe_ss, pe_ee], dim=-1)
            rel_pos_embedding = self.pos_fusion_forward(pe_2)
        elif self.four_pos_fusion == 'attn':
            pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
            attn_score = self.pos_attn_score(pe_4)
            pe_4_unflat = self.w_r(pe_4.view(batch, max_seq_len, max_seq_len, 4, self.hidden_size))
            pe_4_fusion = (attn_score.unsqueeze(-1) * pe_4_unflat).sum(dim=-2)
            rel_pos_embedding = pe_4_fusion
        elif self.four_pos_fusion == 'gate':
            pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
            gate_score = self.pos_gate_score(pe_4).view(batch,max_seq_len,max_seq_len,4,self.hidden_size)
            gate_score = F.softmax(gate_score, dim=-2)
            pe_4_unflat = self.w_r(pe_4.view(batch, max_seq_len, max_seq_len, 4, self.hidden_size))
            pe_4_fusion = (gate_score * pe_4_unflat).sum(dim=-2)
            rel_pos_embedding = pe_4_fusion

        return rel_pos_embedding

  • The forward function takes pos_s (head) and pos_e (tail) and derives the 4 kinds of relative position information pos_ss, pos_se, pos_es, pos_ee (a tiny worked example follows this list)
  • The 4 kinds of position information are converted into the corresponding position encodings pe_ss, pe_se, pe_es, pe_ee
  • Finally the 4 position encodings are fused. Five fusion methods are provided: ff is a fully-connected layer with a nonlinear activation, attn first computes a weight for each position encoding and then takes a weighted sum, and gate is similar to attn except the weighting has one extra (hidden) dimension.
    The default is ff_two, which fuses only pe_ss and pe_ee; the fused relative position encoding rel_pos_embedding has shape [bs, max_seq_len, max_seq_len, hidden_size].
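
A tiny worked example of the distance computation (the head/tail indices below are made up: 3 characters plus one word spanning characters 0..1):

import torch

pos_s = torch.tensor([[0, 1, 2, 0]])   # head positions of the 4 tokens
pos_e = torch.tensor([[0, 1, 2, 1]])   # tail positions of the 4 tokens

pos_ss = pos_s.unsqueeze(-1) - pos_s.unsqueeze(-2)   # head_i - head_j
pos_ee = pos_e.unsqueeze(-1) - pos_e.unsqueeze(-2)   # tail_i - tail_j
print(pos_ss[0])
# tensor([[ 0, -1, -2,  0],
#         [ 1,  0, -1,  1],
#         [ 2,  1,  0,  2],
#         [ 0, -1, -2,  0]])
# Each (possibly negative) distance is shifted by +max_seq_len before indexing the
# [2*max_seq_len+1, hidden_size] sinusoidal table.
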
  2. The core part: Transformer_Encoder_Layer()
for i in range(self.num_layers):
    setattr(self, 'layer_{}'.format(i), Transformer_Encoder_Layer(hidden_size, num_heads, ...))

The forward function of Transformer_Encoder

    def forward(self, inp, seq_len, lex_num=0, pos_s=None, pos_e=None, print_=False):
        output = inp
        if self.relative_position:
            if self.four_pos_fusion_shared and self.lattice:
                rel_pos_embedding = self.four_pos_fusion_embedding(pos_s, pos_e)
            else:
                rel_pos_embedding = None
        else:
            rel_pos_embedding = None
        for i in range(self.num_layers):
            now_layer = getattr(self, 'layer_{}'.format(i))  # the i-th stacked Transformer_Encoder_Layer
            output = now_layer(output, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e,
                               rel_pos_embedding=rel_pos_embedding, print_=print_)

        output = self.layer_preprocess(output)
        return output

We can see that now_layer calls each Transformer_Encoder_Layer() layer by layer for the forward pass, and the relative position encoding rel_pos_embedding is passed in as well.

So let's analyze the key parts of the most central code block, Transformer_Encoder_Layer():

  3. The layer's forward function
      self.ff = Positionwise_FeedForward([hidden_size, ff_size, hidden_size], self.dropout,ff_activate=self.ff_activate,
                                           use_pytorch_dropout=self.use_pytorch_dropout)

    def forward(self, inp, seq_len, lex_num=0, pos_s=None, pos_e=None, rel_pos_embedding=None,
                print_=False):
        output = inp
        output = self.layer_preprocess(output)
        if self.lattice:
            if self.relative_position:
                if rel_pos_embedding is None:
                    rel_pos_embedding = self.four_pos_fusion_embedding(pos_s,pos_e)
                output = self.attn(output, output, output, seq_len, pos_s=pos_s, pos_e=pos_e, lex_num=lex_num,
                                   rel_pos_embedding=rel_pos_embedding)
            else:
                output = self.attn(output, output, output, seq_len, lex_num)
        else:
            output = self.attn(output, output, output, seq_len)

        output = self.layer_postprocess(output)
        output = self.layer_preprocess(output)

        output = self.ff(output, print_)

        output = self.layer_postprocess(output)

        return output
  • self.attn() is the Self-Attention layer
  • self.layer_postprocess() is meant to perform the Add & Norm step (note that the author's implementation has a bug here and does not actually apply the residual connection; a standard Add & Norm is sketched below for comparison)
  • self.ff() is the FFN layer
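
For comparison, a minimal sketch of what a standard Add & Norm step looks like (illustrative only, not the repo's layer_postprocess):

import torch.nn as nn

class AddAndNorm(nn.Module):
    """Residual connection followed by LayerNorm, as in the standard Transformer."""
    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_output):
        # add the sublayer output back onto its input, then normalize
        return self.norm(x + self.dropout(sublayer_output))
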
  4. Incorporating the relative position encoding into the Self-Attention layer
self.attn = MultiHead_Attention_Lattice_rel_save_gpumm(self.hidden_size, self.num_heads,
                                                    pe=self.pe,
                                                    pe_ss=self.pe_ss,
                                                    pe_se=self.pe_se,
                                                    pe_es=self.pe_es,
                                                    pe_ee=self.pe_ee,
                                                    scaled=self.scaled,
                                                    mode=self.mode,
                                                    max_seq_len=self.max_seq_len,
                                                    dvc=self.dvc,
                                                    k_proj=self.k_proj,
                                                    q_proj=self.q_proj,
                                                    v_proj=self.v_proj,
                                                    r_proj=self.r_proj,
                                                    attn_dropout=self.dropout['attn'],
                                                    ff_final=self.attn_ff,  # False
                                                    four_pos_fusion=self.four_pos_fusion,
                                                    use_pytorch_dropout=self.use_pytorch_dropout)

Let's look at the key parts of the self-attention layer in detail:

class MultiHead_Attention_Lattice_rel_save_gpumm(nn.Module):
    def __init__(self, hidden_size, num_heads, ...):
        ...  # some code omitted
        self.per_head_size = self.hidden_size // self.num_heads
        self.w_k = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_q = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_v = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_r = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_final = nn.Linear(self.hidden_size, self.hidden_size)
        self.u = nn.Parameter(torch.Tensor(self.num_heads, self.per_head_size))
        self.v = nn.Parameter(torch.Tensor(self.num_heads, self.per_head_size))

    def forward(self, key, query, value, seq_len, lex_num, pos_s, pos_e, rel_pos_embedding):
        if self.k_proj:
            key = self.w_k(key)
        if self.q_proj:
            query = self.w_q(query)
        if self.v_proj:
            value = self.w_v(value)
        if self.r_proj:
            # [bs, max_seq_len, max_seq_len, hidden_size] 
            rel_pos_embedding = self.w_r(rel_pos_embedding)

        batch = key.size(0)
        max_seq_len = key.size(1)
        
        # batch * seq_len * n_head * per_head_size
        key = torch.reshape(key, [batch, max_seq_len, self.num_heads, self.per_head_size])
        query = torch.reshape(query, [batch, max_seq_len, self.num_heads, self.per_head_size])
        value = torch.reshape(value, [batch, max_seq_len, self.num_heads, self.per_head_size])
        rel_pos_embedding = torch.reshape(rel_pos_embedding,
                                          [batch, max_seq_len, max_seq_len, self.num_heads, self.per_head_size])
        
        # batch * n_head * seq_len * per_head_size
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)

        # batch * n_head * per_head_size * key_len
        key = key.transpose(-1, -2)

        # u_for_c: 1(batch broadcast) * num_heads * 1 * per_head_size
        u_for_c = self.u.unsqueeze(0).unsqueeze(-2)
        query_and_u_for_c = query + u_for_c
        # batch * n_head * seq_len * seq_len
        A_C = torch.matmul(query_and_u_for_c, key)  

        rel_pos_embedding_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)
        # after above, rel_pos_embedding: batch * num_head * query_len * per_head_size * key_len
        query_for_b = query.view([batch, self.num_heads, max_seq_len, 1, self.per_head_size])
        # after above, query_for_b: batch * num_head * query_len * 1 * per_head_size

        query_for_b_and_v_for_d = query_for_b + self.v.view(1, self.num_heads, 1, 1, self.per_head_size)
        B_D = torch.matmul(query_for_b_and_v_for_d, rel_pos_embedding_for_b).squeeze(-2)

        attn_score_raw = A_C + B_D

        if self.scaled:
            attn_score_raw  = attn_score_raw / math.sqrt(self.per_head_size)

        mask = seq_len_to_mask(seq_len+lex_num).bool().unsqueeze(1).unsqueeze(1)
        attn_score_raw_masked = attn_score_raw.masked_fill(~mask, -1e15)

        attn_score = F.softmax(attn_score_raw_masked,dim=-1)
        attn_score = self.dropout(attn_score)

        value_weighted_sum = torch.matmul(attn_score, value)
        result = value_weighted_sum.transpose(1, 2).contiguous(). \
            reshape(batch, max_seq_len, self.hidden_size)

        return result  # [batch, max_seq_len, hidden_size]

  • The letters a, b, c, d (and their uppercase forms) in the variable names denote the first, second, third and fourth terms of Equation (11) in the paper (the decomposition is written out below).
  • A_C is the sum of the first and third terms of Equation (11)
  • B_D is the sum of the second and fourth terms of Equation (11)
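
For reference, the decomposition follows the Transformer-XL style relative attention that the paper adopts; written out schematically (see Equation (11) in the paper for the exact notation):

A^{*}_{i,j}
= \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_{k,E}\,\mathbf{E}_{x_j}}_{(a)}
+ \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_{k,R}\,\mathbf{R}_{ij}}_{(b)}
+ \underbrace{\mathbf{u}^{\top}\mathbf{W}_{k,E}\,\mathbf{E}_{x_j}}_{(c)}
+ \underbrace{\mathbf{v}^{\top}\mathbf{W}_{k,R}\,\mathbf{R}_{ij}}_{(d)}

In the code, terms (a) and (c) are computed jointly as A_C via (query + u) · key, and terms (b) and (d) jointly as B_D via (query + v) · rel_pos_embedding, where R_ij corresponds to the fused rel_pos_embedding.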

This completes a fairly detailed walkthrough of the key code of the FLAT model architecture.

In addition, the authors provide a V1 version; its main difference from V0 is that it uses BERT embeddings.

References:
FLAT: Chinese NER Using Flat-Lattice Transformer (github.com)
NLP项目实践——中文序列标注Flat Lattice代码解读、运行与使用 (CSDN blog)
