SpERT是一个以bert的预训练语言模型为基础,进行联合实体识别和关系抽取的模型。文章设计了一个联合实体识别和关系抽取模型架构,并使用基于跨度的负采样形式,在ADE、CoNLL04和SciERC三个数据集上均达到了sota效果。本文将结合论文以及主体源码对模型进行解读,论文与全部源码详见参考资料
模型主要由 span classification 、 Span Filtering 和 relation classification 三部分组成。 span classification 和 Span Filtering 对实体进行筛选和识别,relation classification 进行关系抽取。模型架构如图所示:
模型的第一部分和第二部分主要进行的是获得实体的向量表示,并使用其进行实体识别,同时过滤掉未被识别出类型的实体。
首先,模型使用bert获取文本的向量表示 ( e 1 , e 2 , . . . e n , c ) (e_1,e_2,...e_n,c) (e1,e2,...en,c) 。之后,模型将在任意跨度检测实体,这里的 c c c 指的是特殊标记CLS代表的词向量。
例如对于句子 (we,will,rock,you) ,可能被检测的实体有(we),(we,will),(will,rock,you)等等。而与之前的一些工作不同的是,本模型并不会对实体和关系假设进行束搜索,而是设定了一个最大值 N e N_e Ne (文章中设定为100),即在所有可能的实体中最多随机选取 N e N_e Ne 个实体,并将未在训练集中被标注为正例的样本标记成负例。同时规定实体的长度不能超过10。这样实体识别的复杂度被限制到 O ( n ) O(n) O(n) 。
实体采样源码的实现如下:
token_count = len(doc.tokens)
# positive entities
pos_entity_spans, pos_entity_types, pos_entity_masks, pos_entity_sizes = [], [], [], []
for e in doc.entities:
pos_entity_spans.append(e.span)
pos_entity_types.append(e.entity_type.index)
pos_entity_masks.append(create_entity_mask(*e.span, context_size))
pos_entity_sizes.append(len(e.tokens))
# negative entities
neg_entity_spans, neg_entity_sizes = [], []
for size in range(1, max_span_size + 1):
for i in range(0, (token_count - size) + 1):
span = doc.tokens[i:i + size].span
if span not in pos_entity_spans:
neg_entity_spans.append(span)
neg_entity_sizes.append(size)
# sample negative entities
neg_entity_samples = random.sample(list(zip(neg_entity_spans, neg_entity_sizes)),
min(len(neg_entity_spans), neg_entity_count))
neg_entity_spans, neg_entity_sizes = zip(*neg_entity_samples) if neg_entity_samples else ([], [])
neg_entity_masks = [create_entity_mask(*span, context_size) for span in neg_entity_spans]
neg_entity_types = [0] * len(neg_entity_spans)
这里的doc就是文本,max_entity_sizes就是10,pos_entity_spans就是训练集中标注为正例的样例,neg_entity_count就是 N e N_e Ne (100)。即选所有可能的跨度,选取在训练集中被标注的样本作为正例,并选取未在训练集中被标注为正例的跨度作为作为负例。如果选取的数量大于最大值 N e N_e Ne ,就在其中随机选取。负采样的实体将会被标注为无类别实体 n o n e none none ,该类别用 0 表示。
在选取完实体后,将为其创建遮蔽矩阵,源码如下:
def create_entity_mask(start, end, context_size):
mask = torch.zeros(context_size, dtype=torch.bool)
mask[start:end] = 1
return mask
在得到正负样例后,将其聚合得到实体样例集合,源码如下:
entity_types = pos_entity_types + neg_entity_types
entity_masks = pos_entity_masks + neg_entity_masks
entity_sizes = pos_entity_sizes + list(neg_entity_sizes)
选取好可能的实体后,下面就是对其向量表示进行处理。将被送入span classifier的向量表示由三部分构成,分别为实体包含的token的向量表示(在模型图中用红的部分表示)、宽度嵌入(蓝色的部分)以及特殊标记CLS(绿色的部分)。
在第一部分中,对于一个可能的实体跨度 ( e i , e i + 1 , . . . , e i + k ) (e_i,e_{i+1},...,e_{i+k}) (ei,ei+1,...,ei+k) (就是以 ( e i , e i + 1 , . . . , e i + k ) (e_i,e_{i+1},...,e_{i+k}) (ei,ei+1,...,ei+k)为下标,在文本中选择一段连续的文本,在模型图中被表示为红色的部分),其向量表示为 f ( e i , e i + 1 , . . . , e i + k ) f(e_i,e_{i+1},...,e_{i+k}) f(ei,ei+1,...,ei+k) ,这里的 f f f 使用了最大池化。至此,模型得到了实体token的向量表示,公式如下:
在第二部分中,宽度嵌入是在训练中学习到的嵌入矩阵,即实体的宽度为 k + 1 k+1 k+1表示实体中包含 k + 1 k+1 k+1个token,那么实体的宽度嵌入 w k + 1 w_{k+1} wk+1 就会被表示为以 k + 1 k+1 k+1为下标,在宽度矩阵中进行索引得到的宽度为 k + 1 k+1 k+1的向量表示。将宽度表示与实体token的向量表示连接,公式如下:
e ( s ) : = f ( e i , e i + 1 , . . . , e i + k ) ∘ w k + 1 \mathbf e(s) :=f(\mathbf e_i,\mathbf e_{i+1},...,\mathbf e_{i+k}) \circ \mathbf w_{k+1} e(s):=f(ei,ei+1,...,ei+k)∘wk+1
在解读第三部分前,先回忆一下bert模型。bert对于特殊标记CLS代表的词向量进行了池化操作,得到最终的CLS向量表示。而在本文中使用的CLS是未经池化的向量 c \mathbf c c。
将这三部分连接,得到了最终的向量表示,公式如下:
x s : = e ( s ) ∘ c \mathbf x^s :=\mathbf e(s) \circ \mathbf c xs:=e(s)∘c
最后,将实体表示送入一个全连接加softmax后,得到了实体的类别,其中也包括了无类别 n o n e none none 。公式如下:
y s ^ = s o f t m a x ( W s ⋅ x s + b s ) \hat \mathbf {y^s} = softmax \Bigr ( \mathit W^s \cdot\mathbf x^s+\mathbf b^s \Bigr ) ys^=softmax(Ws⋅xs+bs)
本部分的源码如下所示:
self.entity_classifier = nn.Linear(config.hidden_size * 2 + size_embedding, entity_types)
def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
# max pool entity candidate spans
m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
entity_spans_pool = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
entity_spans_pool = entity_spans_pool.max(dim=2)[0]
# get cls token as candidate context representation
entity_ctx = get_token(h, encodings, self._cls_token)
# create candidate representations including context, max pooled span and size embedding
entity_repr = torch.cat([entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
entity_spans_pool, size_embeddings], dim=2)
entity_repr = self.dropout(entity_repr)
# classify entity candidates
entity_clf = self.entity_classifier(entity_repr)
return entity_clf, entity_spans_pool
这里的entity_mask矩阵是在采样的过程中构建的,即选择文本中一段连续的token作为实体是,将这部分设为1,其他部分设为0 。在_classify_entities函数中,把entity_mask中为0的项设为-1e30得到遮蔽矩阵m。先将文本向量h重复entity_masks.shape[1]
次,即有实体样例集合中有多少个样本,就将h重复多少次。将其与m相加时,每个样本对应的实体遮蔽矩阵都会与一个文本向量表示相加。这样文本中不在实体里的token会被减去一个很大的值,从而在entity_spans_pool = entity_spans_pool.max(dim=2)[0]
进行最大池化的过程中由于值很小不会影响池化的结果。之后通过下标索引的方式,获得了未经池化的CLS向量表示entity_ctx。将实体表示、宽度嵌入以及CLS连接后,过一个全连接即可得到实体分类的结果。
在模型的第三部分Relation Classification中,模型构造了可能的关系并对其进行分类。
模型首先从可能的实体中随机选择最多 N r N_r Nr 对实体组成关系集合(这里的 N r N_r Nr 也为100)。对于一个由一个实体对 ( s 1 , s 2 ) (s_1,s_2) (s1,s2) 构成的实体,其关系向量表示由两部分构成。一部分是上面第一个公式得到的实体向量表示 e ( s 1 ) , e ( s 2 ) \mathbf e(s_1),\mathbf e(s_2) e(s1),e(s2) ,也就是模型图中红色的部分;
# positive relations
pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []
for rel in doc.relations:
s1, s2 = rel.head_entity.span, rel.tail_entity.span
pos_rels.append((pos_entity_spans.index(s1), pos_entity_spans.index(s2)))
pos_rel_spans.append((s1, s2))
pos_rel_types.append(rel.relation_type)
pos_rel_masks.append(create_rel_mask(s1, s2, context_size))
# negative relations
# use only strong negative relations, i.e. pairs of actual (labeled) entities that are not related
neg_rel_spans = []
for i1, s1 in enumerate(pos_entity_spans):
for i2, s2 in enumerate(pos_entity_spans):
rev = (s2, s1)
rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric
# do not add as negative relation sample:
# neg. relations from an entity to itself
# entity pairs that are related according to gt
# entity pairs whose reverse exists as a symmetric relation in gt
if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric:
neg_rel_spans.append((s1, s2))
# sample negative relations
neg_rel_spans = random.sample(neg_rel_spans, min(len(neg_rel_spans), neg_rel_count))
neg_rels = [(pos_entity_spans.index(s1), pos_entity_spans.index(s2)) for s1, s2 in neg_rel_spans]
neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans]
neg_rel_types = [0] * len(neg_rel_spans)
这里的首先获得非零实体 non_zero_indices ,0 表示不是已知实体类型,未检测出类型的实体将会被过滤。之后再从已检测到的实体中两两选择关系。这里的symmetric是关系的一个属性,表示这个关系是否是对称的。在创建负例的时候,不会创建实体到自身的关系、在训练集中被标注为正例的关系和这个关系的对称关系(如果存在对称关系的话)。
与上面选择实体相同的是,neg_rel_count代表 N r N_r Nr ,即负采样的关系数目超过 N r N_r Nr 的话就会在其中随机采样。负采样得到的关系将会被标注为 n o n e none none ,即句子中两个实体之间不存在关系,该关系类别用 0 表示。
与实体选择类似,这里也会创建遮蔽矩阵,源码如下:
def create_rel_mask(s1, s2, context_size):
start = s1[1] if s1[1] < s2[0] else s2[1]
end = s2[0] if s1[1] < s2[0] else s1[0]
mask = create_entity_mask(start, end, context_size)
return mask
在进行完关系构建与选择后,将正负样例聚合得到关系集合,源码如下:
rels = pos_rels + neg_rels
rel_types = [r.index for r in pos_rel_types] + neg_rel_types
rel_masks = pos_rel_masks + neg_rel_masks
除了实体特征以外,关系抽取也要依赖文本特征。由于特殊标记CLS有文本分类的作用,关系抽取的模型架构往往会使用CLS所代表的词向量作为关系抽取的输入之一。而在本文中,并没有选择CLS作为文本特征,而是对于两个实体之间的文本进行了最大池化 ,得到了文本特征的向量表示 c ( s 1 , s 2 ) \mathbf c(s_1,s_2) c(s1,s2) ,也就是模型图中黄色的部分。如果两个实体之间没有文本,那么 c ( s 1 , s 2 ) \mathbf c(s_1,s_2) c(s1,s2) 将被设置为0 。
至此,我们得到了关系的向量表示,由于关系往往是非对称的,所以每一个实体对将会得到两个关系表示。公式如下:
x 1 r : = e ( s 1 ) ∘ c ( s 1 , s 2 ) ∘ e ( s 2 ) \mathbf x^r_1 := \mathbf e(s_1) \circ \mathbf c(s_1,s_2) \circ \mathbf e(s_2) x1r:=e(s1)∘c(s1,s2)∘e(s2)
x 2 r : = e ( s 2 ) ∘ c ( s 1 , s 2 ) ∘ e ( s 1 ) \mathbf x^r_2 := \mathbf e(s_2) \circ \mathbf c(s_1,s_2) \circ \mathbf e(s_1) x2r:=e(s2)∘c(s1,s2)∘e(s1)
接下来这两个关系将会过一个全连接后再用一个sigmoid激活,公式如下:
y ^ 1 / 2 r : = σ ( W r ⋅ x 1 / 2 r + b r ) \hat \mathbf y^r_{1/2} := \sigma \Bigr ( W^r \cdot \mathbf x^r_{1/2} +\mathbf b^r \Bigr) y^1/2r:=σ(Wr⋅x1/2r+br)
最后,模型的损失是实体分类损失 L s \mathcal L^s Ls 与关系分类损失 L r \mathcal L^r Lr 之和,公式如下:
L = L s + L r \mathcal L = \mathcal L^s +\mathcal L^r L=Ls+Lr
至此,模型整体的架构已经比较清楚了,下面我们来看一下关系构建与选择部分的代码
经过关系构建与选择后,我们得到了候选的关系。接下来我们看一下关系分类部分的源码:
def batch_index(tensor, index, pad=False):
if tensor.shape[0] != index.shape[0]:
raise Exception()
if not pad:
return torch.stack([tensor[i][index[i]] for i in range(index.shape[0])])
else:
return padded_stack([tensor[i][index[i]] for i in range(index.shape[0])])
self.rel_classifier = nn.Linear(config.hidden_size * 3 + size_embedding * 2, relation_types)
# get pairs of entity candidate representations
entity_pairs = util.batch_index(entity_spans, relations)
entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1)
# get corresponding size embeddings
size_pair_embeddings = util.batch_index(size_embeddings, relations)
size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)
# relation context (context between entity candidate pair)
# mask non entity candidate tokens
m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
rel_ctx = m + h
# max pooling
rel_ctx = rel_ctx.max(dim=2)[0]
# set the context vector of neighboring or adjacent entity candidates to zero
rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0
# create relation candidate representations including context, max pooled entity candidate pairs
# and corresponding size embeddings
rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings], dim=2)
rel_repr = self.dropout(rel_repr)
# classify relation candidates
chunk_rel_logits = self.rel_classifier(rel_repr)
return chunk_rel_logits
在关系分类的过程中,首先根据下标获得了实体对的向量表示和宽度嵌入,也就是模型图中红色和蓝色的部分。之后将关系遮蔽矩阵与文本向量表示h相加,得到文本特征。与实体分类的处理类似,这里的h也是重复过n次的(这个处理由于篇幅问题就没有贴出来,更详细的处理可以查看github上的源码)。将这两个实体特征与文本特征连接后,过一个全连接后,得到了最终的分类结果。
SpERT的主要创新点在于抛弃了传统的BIO/BILOU标注实体的方式,构建了一个基于跨度的联合实体识别和关系抽取模型。同时使用了负采样的方式增强模型,并通过最大池化的方式,提取了一个关系中实体对之间的文本特征。
论文:
https://arxiv.org/abs/1909.07755
源码:
https://github.com/markus-eberts/spert/blob/master/README.md