bert关系抽取论文源码之SpERT:Span-based Joint Entity and Relation Extraction with Transformer Pre-training

目录

  • 前言
  • 模型架构
    • Span Classification & Span Filtering
      • 实体选择以及负采样
      • 实体表示及分类
    • Relation Classification
      • 关系构造以及负采样
      • 文本特征与关系分类
  • 结语
  • 参考资料

前言

SpERT是一个以bert的预训练语言模型为基础,进行联合实体识别和关系抽取的模型。文章设计了一个联合实体识别和关系抽取模型架构,并使用基于跨度的负采样形式,在ADE、CoNLL04和SciERC三个数据集上均达到了sota效果。本文将结合论文以及主体源码对模型进行解读,论文与全部源码详见参考资料

模型架构

模型主要由 span classification 、 Span Filtering 和 relation classification 三部分组成。 span classification 和 Span Filtering 对实体进行筛选和识别,relation classification 进行关系抽取。模型架构如图所示:

bert关系抽取论文源码之SpERT:Span-based Joint Entity and Relation Extraction with Transformer Pre-training_第1张图片

Span Classification & Span Filtering

模型的第一部分和第二部分主要进行的是获得实体的向量表示,并使用其进行实体识别,同时过滤掉未被识别出类型的实体。

实体选择以及负采样

首先,模型使用bert获取文本的向量表示 ( e 1 , e 2 , . . . e n , c ) (e_1,e_2,...e_n,c) (e1,e2,...en,c) 。之后,模型将在任意跨度检测实体,这里的 c c c 指的是特殊标记CLS代表的词向量。

例如对于句子 (we,will,rock,you) ,可能被检测的实体有(we),(we,will),(will,rock,you)等等。而与之前的一些工作不同的是,本模型并不会对实体和关系假设进行束搜索,而是设定了一个最大值 N e N_e Ne (文章中设定为100),即在所有可能的实体中最多随机选取 N e N_e Ne 个实体,并将未在训练集中被标注为正例的样本标记成负例。同时规定实体的长度不能超过10。这样实体识别的复杂度被限制到 O ( n ) O(n) O(n)

实体采样源码的实现如下:


	token_count = len(doc.tokens)

    # positive entities
    pos_entity_spans, pos_entity_types, pos_entity_masks, pos_entity_sizes = [], [], [], []
    for e in doc.entities:
        pos_entity_spans.append(e.span)
        pos_entity_types.append(e.entity_type.index)
        pos_entity_masks.append(create_entity_mask(*e.span, context_size))
        pos_entity_sizes.append(len(e.tokens))
	
    # negative entities
    neg_entity_spans, neg_entity_sizes = [], []
    for size in range(1, max_span_size + 1):
        for i in range(0, (token_count - size) + 1):
            span = doc.tokens[i:i + size].span
            if span not in pos_entity_spans:
                neg_entity_spans.append(span)
                neg_entity_sizes.append(size)
                
    # sample negative entities
    neg_entity_samples = random.sample(list(zip(neg_entity_spans, neg_entity_sizes)),
                                       min(len(neg_entity_spans), neg_entity_count))
    neg_entity_spans, neg_entity_sizes = zip(*neg_entity_samples) if neg_entity_samples else ([], [])
    neg_entity_masks = [create_entity_mask(*span, context_size) for span in neg_entity_spans]
    neg_entity_types = [0] * len(neg_entity_spans)
    

这里的doc就是文本,max_entity_sizes就是10,pos_entity_spans就是训练集中标注为正例的样例,neg_entity_count就是 N e N_e Ne (100)。即选所有可能的跨度,选取在训练集中被标注的样本作为正例,并选取未在训练集中被标注为正例的跨度作为作为负例。如果选取的数量大于最大值 N e N_e Ne ,就在其中随机选取。负采样的实体将会被标注为无类别实体 n o n e none none ,该类别用 0 表示。

在选取完实体后,将为其创建遮蔽矩阵,源码如下:


def create_entity_mask(start, end, context_size):
    mask = torch.zeros(context_size, dtype=torch.bool)
    mask[start:end] = 1
    return mask

在得到正负样例后,将其聚合得到实体样例集合,源码如下:


    entity_types = pos_entity_types + neg_entity_types
    entity_masks = pos_entity_masks + neg_entity_masks
    entity_sizes = pos_entity_sizes + list(neg_entity_sizes)

实体表示及分类

选取好可能的实体后,下面就是对其向量表示进行处理。将被送入span classifier的向量表示由三部分构成,分别为实体包含的token的向量表示(在模型图中用红的部分表示)、宽度嵌入(蓝色的部分)以及特殊标记CLS(绿色的部分)。

在第一部分中,对于一个可能的实体跨度 ( e i , e i + 1 , . . . , e i + k ) (e_i,e_{i+1},...,e_{i+k}) (ei,ei+1,...,ei+k) (就是以 ( e i , e i + 1 , . . . , e i + k ) (e_i,e_{i+1},...,e_{i+k}) (ei,ei+1,...,ei+k)为下标,在文本中选择一段连续的文本,在模型图中被表示为红色的部分),其向量表示为 f ( e i , e i + 1 , . . . , e i + k ) f(e_i,e_{i+1},...,e_{i+k}) f(ei,ei+1,...,ei+k) ,这里的 f f f 使用了最大池化。至此,模型得到了实体token的向量表示,公式如下:

在第二部分中,宽度嵌入是在训练中学习到的嵌入矩阵,即实体的宽度为 k + 1 k+1 k+1表示实体中包含 k + 1 k+1 k+1个token,那么实体的宽度嵌入 w k + 1 w_{k+1} wk+1 就会被表示为以 k + 1 k+1 k+1为下标,在宽度矩阵中进行索引得到的宽度为 k + 1 k+1 k+1的向量表示。将宽度表示与实体token的向量表示连接,公式如下:

e ( s ) : = f ( e i , e i + 1 , . . . , e i + k ) ∘ w k + 1 \mathbf e(s) :=f(\mathbf e_i,\mathbf e_{i+1},...,\mathbf e_{i+k}) \circ \mathbf w_{k+1} e(s):=f(ei,ei+1,...,ei+k)wk+1

在解读第三部分前,先回忆一下bert模型。bert对于特殊标记CLS代表的词向量进行了池化操作,得到最终的CLS向量表示。而在本文中使用的CLS是未经池化的向量 c \mathbf c c

将这三部分连接,得到了最终的向量表示,公式如下:

x s : = e ( s ) ∘ c \mathbf x^s :=\mathbf e(s) \circ \mathbf c xs:=e(s)c

最后,将实体表示送入一个全连接加softmax后,得到了实体的类别,其中也包括了无类别 n o n e none none 。公式如下:

y s ^ = s o f t m a x ( W s ⋅ x s + b s ) \hat \mathbf {y^s} = softmax \Bigr ( \mathit W^s \cdot\mathbf x^s+\mathbf b^s \Bigr ) ys^=softmax(Wsxs+bs)

本部分的源码如下所示:


        self.entity_classifier = nn.Linear(config.hidden_size * 2 + size_embedding, entity_types)

    def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
        # max pool entity candidate spans
        m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
        entity_spans_pool = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
        entity_spans_pool = entity_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        entity_ctx = get_token(h, encodings, self._cls_token)

        # create candidate representations including context, max pooled span and size embedding
        entity_repr = torch.cat([entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
                                 entity_spans_pool, size_embeddings], dim=2)
        entity_repr = self.dropout(entity_repr)

        # classify entity candidates
        entity_clf = self.entity_classifier(entity_repr)

        return entity_clf, entity_spans_pool

这里的entity_mask矩阵是在采样的过程中构建的,即选择文本中一段连续的token作为实体是,将这部分设为1,其他部分设为0 。在_classify_entities函数中,把entity_mask中为0的项设为-1e30得到遮蔽矩阵m。先将文本向量h重复entity_masks.shape[1]次,即有实体样例集合中有多少个样本,就将h重复多少次。将其与m相加时,每个样本对应的实体遮蔽矩阵都会与一个文本向量表示相加。这样文本中不在实体里的token会被减去一个很大的值,从而在entity_spans_pool = entity_spans_pool.max(dim=2)[0]进行最大池化的过程中由于值很小不会影响池化的结果。之后通过下标索引的方式,获得了未经池化的CLS向量表示entity_ctx。将实体表示、宽度嵌入以及CLS连接后,过一个全连接即可得到实体分类的结果。

Relation Classification

在模型的第三部分Relation Classification中,模型构造了可能的关系并对其进行分类。

关系构造以及负采样

模型首先从可能的实体中随机选择最多 N r N_r Nr 对实体组成关系集合(这里的 N r N_r Nr 也为100)。对于一个由一个实体对 ( s 1 , s 2 ) (s_1,s_2) (s1,s2) 构成的实体,其关系向量表示由两部分构成。一部分是上面第一个公式得到的实体向量表示 e ( s 1 ) , e ( s 2 ) \mathbf e(s_1),\mathbf e(s_2) e(s1),e(s2) ,也就是模型图中红色的部分;


    # positive relations
    pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []
    for rel in doc.relations:
        s1, s2 = rel.head_entity.span, rel.tail_entity.span
        pos_rels.append((pos_entity_spans.index(s1), pos_entity_spans.index(s2)))
        pos_rel_spans.append((s1, s2))
        pos_rel_types.append(rel.relation_type)
        pos_rel_masks.append(create_rel_mask(s1, s2, context_size))

    # negative relations
    # use only strong negative relations, i.e. pairs of actual (labeled) entities that are not related
    neg_rel_spans = []

    for i1, s1 in enumerate(pos_entity_spans):
        for i2, s2 in enumerate(pos_entity_spans):
            rev = (s2, s1)
            rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric

            # do not add as negative relation sample:
            # neg. relations from an entity to itself
            # entity pairs that are related according to gt
            # entity pairs whose reverse exists as a symmetric relation in gt
            if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric:
                neg_rel_spans.append((s1, s2))

    # sample negative relations
    neg_rel_spans = random.sample(neg_rel_spans, min(len(neg_rel_spans), neg_rel_count))
    neg_rels = [(pos_entity_spans.index(s1), pos_entity_spans.index(s2)) for s1, s2 in neg_rel_spans]
    neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans]
    neg_rel_types = [0] * len(neg_rel_spans)

这里的首先获得非零实体 non_zero_indices ,0 表示不是已知实体类型,未检测出类型的实体将会被过滤。之后再从已检测到的实体中两两选择关系。这里的symmetric是关系的一个属性,表示这个关系是否是对称的。在创建负例的时候,不会创建实体到自身的关系、在训练集中被标注为正例的关系和这个关系的对称关系(如果存在对称关系的话)。
与上面选择实体相同的是,neg_rel_count代表 N r N_r Nr ,即负采样的关系数目超过 N r N_r Nr 的话就会在其中随机采样。负采样得到的关系将会被标注为 n o n e none none ,即句子中两个实体之间不存在关系,该关系类别用 0 表示。

与实体选择类似,这里也会创建遮蔽矩阵,源码如下:


def create_rel_mask(s1, s2, context_size):
    start = s1[1] if s1[1] < s2[0] else s2[1]
    end = s2[0] if s1[1] < s2[0] else s1[0]
    mask = create_entity_mask(start, end, context_size)
    return mask

在进行完关系构建与选择后,将正负样例聚合得到关系集合,源码如下:


    rels = pos_rels + neg_rels
    rel_types = [r.index for r in pos_rel_types] + neg_rel_types
    rel_masks = pos_rel_masks + neg_rel_masks

文本特征与关系分类

除了实体特征以外,关系抽取也要依赖文本特征。由于特殊标记CLS有文本分类的作用,关系抽取的模型架构往往会使用CLS所代表的词向量作为关系抽取的输入之一。而在本文中,并没有选择CLS作为文本特征,而是对于两个实体之间的文本进行了最大池化 ,得到了文本特征的向量表示 c ( s 1 , s 2 ) \mathbf c(s_1,s_2) c(s1,s2) ,也就是模型图中黄色的部分。如果两个实体之间没有文本,那么 c ( s 1 , s 2 ) \mathbf c(s_1,s_2) c(s1,s2) 将被设置为0 。

至此,我们得到了关系的向量表示,由于关系往往是非对称的,所以每一个实体对将会得到两个关系表示。公式如下:

x 1 r : = e ( s 1 ) ∘ c ( s 1 , s 2 ) ∘ e ( s 2 ) \mathbf x^r_1 := \mathbf e(s_1) \circ \mathbf c(s_1,s_2) \circ \mathbf e(s_2) x1r:=e(s1)c(s1,s2)e(s2)

x 2 r : = e ( s 2 ) ∘ c ( s 1 , s 2 ) ∘ e ( s 1 ) \mathbf x^r_2 := \mathbf e(s_2) \circ \mathbf c(s_1,s_2) \circ \mathbf e(s_1) x2r:=e(s2)c(s1,s2)e(s1)

接下来这两个关系将会过一个全连接后再用一个sigmoid激活,公式如下:

y ^ 1 / 2 r : = σ ( W r ⋅ x 1 / 2 r + b r ) \hat \mathbf y^r_{1/2} := \sigma \Bigr ( W^r \cdot \mathbf x^r_{1/2} +\mathbf b^r \Bigr) y^1/2r:=σ(Wrx1/2r+br)

最后,模型的损失是实体分类损失 L s \mathcal L^s Ls 与关系分类损失 L r \mathcal L^r Lr 之和,公式如下:

L = L s + L r \mathcal L = \mathcal L^s +\mathcal L^r L=Ls+Lr

至此,模型整体的架构已经比较清楚了,下面我们来看一下关系构建与选择部分的代码

经过关系构建与选择后,我们得到了候选的关系。接下来我们看一下关系分类部分的源码:


def batch_index(tensor, index, pad=False):
    if tensor.shape[0] != index.shape[0]:
        raise Exception()

    if not pad:
        return torch.stack([tensor[i][index[i]] for i in range(index.shape[0])])
    else:
        return padded_stack([tensor[i][index[i]] for i in range(index.shape[0])])

self.rel_classifier = nn.Linear(config.hidden_size * 3 + size_embedding * 2, relation_types)

        # get pairs of entity candidate representations
        entity_pairs = util.batch_index(entity_spans, relations)
        entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1)

        # get corresponding size embeddings
        size_pair_embeddings = util.batch_index(size_embeddings, relations)
        size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)

        # relation context (context between entity candidate pair)
        # mask non entity candidate tokens
        m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
        rel_ctx = m + h
        # max pooling
        rel_ctx = rel_ctx.max(dim=2)[0]
        # set the context vector of neighboring or adjacent entity candidates to zero
        rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0

        # create relation candidate representations including context, max pooled entity candidate pairs
        # and corresponding size embeddings
        rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings], dim=2)
        rel_repr = self.dropout(rel_repr)

        # classify relation candidates
        chunk_rel_logits = self.rel_classifier(rel_repr)
        return chunk_rel_logits

在关系分类的过程中,首先根据下标获得了实体对的向量表示和宽度嵌入,也就是模型图中红色和蓝色的部分。之后将关系遮蔽矩阵与文本向量表示h相加,得到文本特征。与实体分类的处理类似,这里的h也是重复过n次的(这个处理由于篇幅问题就没有贴出来,更详细的处理可以查看github上的源码)。将这两个实体特征与文本特征连接后,过一个全连接后,得到了最终的分类结果。

结语

SpERT的主要创新点在于抛弃了传统的BIO/BILOU标注实体的方式,构建了一个基于跨度的联合实体识别和关系抽取模型。同时使用了负采样的方式增强模型,并通过最大池化的方式,提取了一个关系中实体对之间的文本特征。

参考资料

论文:
https://arxiv.org/abs/1909.07755

源码:
https://github.com/markus-eberts/spert/blob/master/README.md

你可能感兴趣的:(关系抽取论文阅读笔记,nlp,自然语言处理,深度学习,pytorch,人工智能)