©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 追一科技
研究方向 | NLP、神经网络
两年前,在百度的“2019 语言与智能技术竞赛”(下称 LIC2019)中,笔者提出了一个新的关系抽取模型(参考《基于 DGCNN 和概率图的轻量级信息抽取模型》),后被进一步发表和命名为“CasRel”,算是当时关系抽取的 SOTA。然而,CasRel 提出时笔者其实也是首次接触该领域,所以现在看来 CasRel 仍有诸多不完善之处,笔者后面也有想过要进一步完善它,但也没想到特别好的设计。
后来,笔者提出了 GlobalPointer 以及近日的 Efficient GlobalPointer,感觉有足够的“材料”来构建新的关系抽取模型了。于是笔者从概率图思想出发,参考了 CasRel 之后的一些 SOTA 设计,最终得到了一版类似 TPLinker 的模型。
基础思路
关系抽取乍看之下是三元组 (即 subject, predicate, object)的抽取,但落到具体实现上,它实际是“五元组” 的抽取,其中 分别是 的首、尾位置,而 则分别是 的首、尾位置。
从概率图的角度来看,我们可以这样构建模型:
1. 设计一个五元组的打分函数 ;
2. 训练时让标注的五元组 ,其余五元组则 ;
3. 预测时枚举所有可能的五元组,输出 的部分。
然而,直接枚举所有的五元组数目太多,假设句子长度为 , 的总数为 ,即便加上 和 的约束,所有五元组的数目也有
这是长度的四次方级别的计算量,实际情况下难以实现,所以必须做一些简化。
简化分解
以我们目前的算力来看,一般最多也就能接受长度平方级别的计算量,所以我们每次顶多能识别“一对”首或尾,为此,我们可以用以下的分解:
要注意的是,该等式属于模型假设,是基于我们对任务的理解以及算力的限制所设计出来的,而不是理论推导出来的。其中,每一项都具直观的意义,比如 、 分别是 subject、object 的首尾打分,通过 和 来析出所有的 subject 和 object。至于后两项,则是 predicate 的匹配, 这一项代表以 subject 和 object 的首特征作为它们自身的表征来进行一次匹配,如果我们能确保 subject 内和 object 内是没有嵌套实体的,那么理论上 就足够析出所有的 predicate 了,但考虑到存在嵌套实体的可能,所以我们还要对实体的尾再进行一次匹配,即 这一项。
此时,训练和预测过程变为:
1. 训练时让标注的五元组 、、、,其余五元组则 、、、;
2. 预测时枚举所有可能的五元组,逐次输出 、、、 的部分,然后取它们的交集作为最终的输出(即同时满足 4 个条件)。
在实现上,由于 、 是用来识别 subject、object 对应的实体的,它相当于有两种实体类型的 NER 任务,所以我们可以用一个 GlobalPointer 来完成;至于 ,它是用来识别 predicate 为 的 对,跟 NER 不同的是,NER 有 的约束而它没有,这里我们同样用 GlobalPointer 来完成,但为了识别出 的部分,要去掉 GlobalPointer 默认的下三角 mask;最后 跟 同理,不再赘述。
这里再回顾一遍:我们知道,作为 NER 模块,GlobalPointer 可以统一识别嵌套和非嵌套的实体,而这是它基于 token-pair 的识别来做到的。所以,我们应该进一步将 GlobalPointer 理解为一个 token-pair 的识别模型,而不是局限在 NER 范围内理解它。认识到这一点之后,我们就能明白上述 、、、 其实都可以用 GlobalPointer 来实现了,而要不要加下三角 mask,则自行根据具体任务背景设置就好。
损失函数
现在我们已经把打分函数都设计好了,那么为了训练模型,就差损失函数了。这里继续使用 GlobalPointer 默认使用的、在《将“softmax+交叉熵”推广到多标签分类问题》中提出的多标签交叉熵,它的一般形式为:
其中 分别是正、负类别的集合。在之前的文章中,我们都是用“multi hot”向量来标记正、负类别的,即如果总类别数为 ,那么我们用一个 维向量来表示,其中正类的位置为 1,负类的位置为 0。然而,在 和 的场景,我们各需要一个 的矩阵来标记,两个加在一起并算上 batch_size 总维度就是 ,以 为例,那么 亿。这也就意味着,如果我们还坚持用“multi hot”的形式表示标签的话,每一步训练我们都要创建一个 1 亿参数量的矩阵,然后还要传到 GPU 中,这样不管是创建还是传输成本都很大。
所以,为了提高训练速度,我们需要实现一个“稀疏版”的多标签交叉熵,即每次都只传输正类所对应的的下标就好,由于正类远远少于负类,这样标签矩阵的尺寸就大大减少了。而“稀疏版”多标签交叉熵,意味着我们要在只知道 和 的前提下去实现式(3)。为此,我们使用的实现方式是:
如果即
这样就通过 和 算出了负类对应的损失,而正类部分的损失保持不变就好。
最后,一般情况下的多标签分类任务正类个数是不定的,这时候我们可以将类的下标从 1 开始,将 0 作为填充标签使得每个样本的标签矩阵大小一致,最后在 loss 的实现上对 0 类进行 mask 处理即可。相应的实现已经内置在 bert4keras 中,详情可以参考“sparse_multilabel_categorical_crossentropy” [1]。
实验结果
为了方便称呼,我们暂且将上述模型称为 GPLinker(GlobalPointer-based Linking),一个基于 bert4keras 的参考实现如下:
脚本链接:task_relation_extraction_gplinker.py [2]
在 LIC2019 上的实验结果如下(CasRel 的代码为 task_relation_extraction.py [3]):
预训练模型是 BERT base,Standard 和 Efficient 的区别是分别使用了标准版GlobalPointer 和 Efficient GlobalPointer。该实验结果说明了两件事情,一是 GPLinker 确实比 CasRel 更加有效,二是 Efficient GlobalPointer 的设计确实能在更少参数的情况下媲美标准版 GlobalPointer 的效果。要知道在 LIC2019 这个任务下,如果使用标准版 GlobalPointer,那么 GPLinker 的参数量接近 1 千万,而用 Efficient GlobalPointer 的话只有 30 万左右。
此外,在 3090 上,相比于“multi hot”版的多标签交叉熵,使用稀疏版多标签交叉熵的模型在训练速度上能提高 1.5 倍而不会损失精度,跟 CasRel 相比,使用了稀疏版多标签交叉熵的 GPLinker 在训练速度上只慢 15%,但是解码速度快将近一倍,算得上又快又好了。
相关工作
而对于了解这两年关系抽取 SOTA 模型进展的同学来说,理解上述模型后,会发现它跟 TPLinker [4] 是非常相似的。确实如此,模型在设计之初确实充分借鉴了 TPLinker,最后的结果也同样跟 TPLinker 很相似。
大体上来说,TPLinker 与 GPLinker 的区别如下:
1. TPLinker 的 token-pair 分类特征是首尾特征后拼接做 Dense 变换得到的,其思想来源于 Additive Attention;GPLinker 则是用 GlobalPointer 实现,其思想来源于 Scaled Dot-Product Attention。平均来说,后者拥有更少的显存占用和更快的计算速度。
2. GPLinker 分开识别 subject 和 object 的实体,而 TPLinker 将 subject 和 object 混合起来统一识别。笔者也在 GPLinker 中尝试了混合识别,发现最终效果跟分开识别没有明显区别。
3. 在 和 ,TPLinker 将其转化为了 个 3 分类问题,这会有明显的类别不平衡问题;而 GPLinker 用到了笔者提出的多标签交叉熵,则不会存在不平衡问题,更容易训练。事实上后来 TPLinker 也意识到了这个问题,并提出了 TPLinker-plus [5],其中也用到了该多标签交叉熵。
当然,在笔者看来,本文的最主要贡献,并不是提出 GPLinker 的这些改动,而是对关系联合抽取模型进行一次“自上而下”的理解:从开始的五元组打分 出发,分析其难处,然后简化分解式(2)来“逐个击破”。希望这个自上而下的理解过程,能给读者在为更复杂的任务设计模型时提供一定的思路。
文章小结
本文分享了一个基于 GlobalPointer 的实体关系联合抽取模型——“GPLinker”,并提供了一个“自上而下”的推导理解给大家参考。
参考文献
[1] https://github.com/bojone/bert4keras/blob/4dcda150b54ded71420c44d25ff282ed30f3ea42/bert4keras/backend.py#L272
[2] https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction_gplinker.py
[3] https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction.py
[4] https://arxiv.org/abs/2010.13415
[5] https://github.com/131250208/TPlinker-joint-extraction/tree/master/tplinker_plus
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·