FLAT代码解读(3)-输出

论文 FLAT: Chinese NER Using Flat-Lattice Transformer(ACL 2020)

在前两篇中,我们对FLAT模型的输入和网络结构的关键代码进行了解读。本篇我们分析模型的输出以及评价指标。

在上一篇模型介绍中,我们得知,模型的输出pred会送给CRF层计算loss,即:

      pred = self.output(encoded)
      mask = seq_len_to_mask(seq_len).bool()

      if self.training:
          loss = self.crf(pred, target, mask).mean(dim=0)
          return {'loss': loss}
      else:
          # 作者将scores命名为path, 应当为笔误,这里改过来
          pred, scores = self.crf.viterbi_decode(pred, mask)
          result = {'pred': pred}
          return result

这里self.crf()的具体代码如下:

self.crf = get_crf_zero_init(self.label_size)

def get_crf_zero_init(label_size, include_start_end_trans=False, 
                      allowed_transitions=None, initial_method=None):
    import torch.nn as nn
    from fastNLP.modules import ConditionalRandomField
    crf = ConditionalRandomField(label_size, include_start_end_trans)

    crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size], requires_grad=True))
    if crf.include_start_end_trans:
        crf.start_scores = nn.Parameter(torch.zeros(size=[label_size], requires_grad=True))
        crf.end_scores = nn.Parameter(torch.zeros(size=[label_size], requires_grad=True))
    return crf

可以发现,这里调用了FastNLP工具包里的ConditionalRandomField类,该类提供了forward()以及viterbi_decode()两个方法,分别用于train和inference。

最后可以看到,作者采用的评价指标为:

f1_metric = SpanFPreRecMetric(vocabs['label'], pred='pred', target='target', 
                              seq_len='seq_len', encoding_type=encoding_type)
acc_metric = AccuracyMetric(pred='pred', target='target', seq_len='seq_len')
acc_metric.set_metric_name('label_acc')
metrics = [
    f1_metric,
    acc_metric
]

这里的SpanFPreRecMetricAccuracyMetric也是FastNLP工具包里类。

  • SpanFPreRecMetric以span的方式计算F1, precision, recall
  • AccuracyMetric 计算accuracy,这里我理解为是计算token-level的acc

至此,我们已基本解读完FLAT官方开源代码的关键细节,若有不当及错误,欢迎批评指正!

参考:
FLAT: Chinese NER Using Flat-Lattice Transformer (github.com)

你可能感兴趣的:(FLAT代码解读(3)-输出)