REALM: Retrieval-Augmented Language Model Pre Training 解读

知识就是力量

培根

 

背景

去年可以说是语言模型快速发展的一年,BERT、XLNET、Albert等等模型不断刷新各个NLP榜单。在NLP榜单中比较引人注目的应该属于阅读理解型的任务,例如SQuAD等等。以SQuAD为例,模型需要阅读一段给定的文本,然后回答几个问题,问题如果存在答案,答案一定可以在文章中找到。所以说虽然叫阅读理解,但其实和序列标注有点相像,是在给定序列中标出答案段。而这篇论文针对的问题叫开放领域问答(Open-domain QA),对于一个问题Q,模型需要从包含大量文档的知识库中找到答案,而不是像SQuAD数据集一样从一篇文章中寻找。

大部分的语言模型都采用一种称为masked language model,简称MLM的任务来训练,让模型学会类似完形填空一样的能力。通过在大规模语料上的训练,预训练语言模型如BERT实际上已经隐含了一些知识。例如输入一句“The       is the currency of the United Kingdom”,BERT很有可能会填入单词"pound"。虽然他还是根据词的共现信息学习和推理的,但看上去就像具有所谓的知识一样。从去年开始就有越来越多的研究从单纯语言模型转换为带有知识嵌入的语言模型,例如清华和百度提出的两个同名模型ERNIE。

但上面说的这种隐含知识不好把握,也难以扩展。这篇论文则提出了一种更加模块化且可解释性更强的知识嵌入方法。总的来说,他的方法是训练一个独立的“语境知识抽取器”(contextual knowledge retriever),通过这个抽取器来决定应该在推理时使用哪些知识。而且这个抽取器和语言模型一起进行非监督预训练大大提高模型性能。

方法

REALM: Retrieval-Augmented Language Model Pre Training 解读_第1张图片

 

如上图所示,整篇论文涉及两个任务,左边是语言模型预训练任务MLM,右边是QA任务。下图是预训练任务一个更加完整的流程图,我们由此切入进行介绍。

REALM: Retrieval-Augmented Language Model Pre Training 解读_第2张图片 标题

整个过程分为两个关键步骤。先看第一步,即neural knowledge retriever,它负责计算p(z|x)。要实现这个过程首先需要对z和x进行编码。论文采用的是BERT,对于问题x,直接输入BERT,取[CLS] token的输出作为编码向量,而对于文档z,则将标题和正文用[SEP]连接后输入BERT,同样去[CLS] token的输出。论文中还对BERT的输出向量进行了降维处理。即

 

 

对于文档库中的某一个z,则

 

 

其中f是问题和文档的相关性,

 

 

以上部分就称为neural knowledge retriever,通过他每篇z都会得到一个p。现在可以进行第二步,综合x和z求y。上图是一个预训练的例子,y是抠掉的词。利用z的方式是将z的正文和x拼在一起来提供上下文信息,然后优化下面的目标

 

REALM: Retrieval-Augmented Language Model Pre Training 解读_第3张图片

 

其中j指第j个masked token。

 

在做QA的时候稍有不同。由于此时是针对某个具体的z,所以作者将开放域问答任务退化成了像SQuAD一样在文档中找答案的阅读理解任务。

 

REALM: Retrieval-Augmented Language Model Pre Training 解读_第4张图片

 

这一部分就是knowledge-augmented encoder

训练

上面已经描述了预训练阶段和QA finetune阶段的任务。训练的过程都是最大化正确y对应的logp(y|z,x),而且以上描述的两个任务都是可以端到端优化的。

但这里面对一个问题,上面有个公式需要对整个知识库中所有的文档z的相关概率求和,这是很困难的。作者提出将这一步用只对概率最高的k个文档计算来近似,因为绝大部分文档由于与问题不相关,p(z|x)都非常小。但问题还没有解决,如何找到概率最高的k个文档呢。

观察公式可以发现p(z|x)是正比于两个编码后的内积的,由于大家的分母都一样,分子的顺序就是整个分数的顺序。所以可以用最大内积搜索算法(Maximum Inner Product Search, MIPS,并不知道是什么,维基百科都没有)来解决。但要构建一个快速检索的索引又要求两个编码后的向量是确定的,而由于编码器是不断训练的,所以这个条件无法满足。为了追求一个平衡,作者决定每隔几百步才更新一下编码器,并重新构建索引。而且这只发生在预训练语言模型的时候,在finetune QA任务的时候只使用语言模型得到的编码器编码一次所有的z和x并构建索引。

 

额外策略

在研究过程中作者发现了一些能让模型更好训练的策略。

  • 只训练真正需要知识的词(通常是实体和日期)来训练MLM

  • 在topk文档外添加一个虚拟的null document

  • 避免让x出现在z中(因为x被mask过,如果它来源于z,那答案就暴露了!)

  • 避免冷启动的retriever太渣导致的恶性循环,他们用了一个以ICT作为任务的模型来初始化retriever

结果对比

这篇论文的对手主要是原来sparse retriever+神经阅读理解模型的组合,例如大名鼎鼎的DrQA。所谓sparse retriever就是指用例如TFIDF之类的特征进行检索的模型。还有一些跟本文很像的neural retriever+neural reader的组合。其中提到了一个ORQA,跟这篇非常像,只是这篇增加了预训练的步骤。最后是一些生成式模型,例如finetune后的T5(可怕!)

在Natural Questions-Open(NQ)、Web Questions(WQ)和Curated Trec(CT)三个数据集上的结果如下

REALM: Retrieval-Augmented Language Model Pre Training 解读_第5张图片

总之一句话,非常牛逼!而且这里模型只取了top 5文档,其他模型可能取了20-80篇,还是打不过他。注意到ours的括号内有两个数据,Z是知识库,很好理解,X是指预训练用的语料。而且通过文章的Ablation Analysis部分可以知道预训练其实是非常关键的一个步骤,对performance的贡献非常大。

 

后记

我感觉这篇论文和他提到的ORQA还是很厉害的,知识嵌入也从去年的实体粒度的嵌入发展到了如今句子、篇章级别的嵌入。试想一下,这项技术发展起来之后,如今基于词的Sparse搜索引擎应该很快会发展成对NN更友好的Dense搜索引擎,所有的内容也许都会被映射到一个向量空间。各路神经网络将在这个向量空间尽情驰骋~莫非得encoder者得天下?!

 

原文首发于公众号:朴素人工智能,欢迎关注

REALM: Retrieval-Augmented Language Model Pre Training 解读_第6张图片

你可能感兴趣的:(自然语言处理,深度学习,机器学习,人工智能)