RocketQA学习

RocketQA学习

paper

本文为2020年百度提出的一种用于针对对偶式检索问答模型的工程优化RocketQA。在本文中提出了三种优化方法,来提高在实际中对偶式模型的效果。包括:跨批次负采样,困难样本去噪以及数据增强训练。

在传统的检索问答模型中,通常使用tfidf,或者bm25这种稀疏向量进行候选项检索,再使用交互模型进行最终结果排序。针对传统的检索问答模型中的检索和排序问题,通过使用对偶式的深度学习模型来实现替换。原文认为通过这样的端到端方式实现检索和排序,降低了系统的复杂性,同时可以让模型基于用户的实时反馈进行训练,加速模型迭代。

“端到端问答”摒弃了传统系统中繁杂的构件,系统复杂性大大降低,并且其中每个模块(段落检索和答案定位)都是可学习的,这样的设计能够让整个系统实现端到端训练。这意味着问答系统可以基于用户实时的反馈实现在线训练,而不是只在封闭的数据集上闭门造车。这是智能问答技术的发展趋势,可能会引发问答系统的新一代技术变革。

对于使用对偶式检索模型,在实际训练中存在以下问题:

  1. 训练场景和预测场景中样本数量存在较大差异

在开放域问答的应用场景中,模型需要从大规模的候选集合中找出问题的答案。但是按照批次内负采样的方法训练时,每个问题的候选段落个数与批次大小相同。受到单 GPU 显存大小的限制,训练过程中见到的候选段落远小于预测时的候选段落,从而导致模型即使在训练时表现良好,在实际应用当中却差强人意。为了降低这种差异,以往的工作会试图设计一种使用困难样本进行训练的机制,然而因为接下来的第二个问题导致效果并不是很好。

  1. 数据集中存在大量漏标注的正确答案

开放域问答场景下候选段落的数量往往非常大,标注出问题的全部正确答案几乎是不可能的。在 MSMARCO 数据集中,候选段落的总数为 880 万,但每个问题平均只标注了 1.1 个正确答案。研究人员发现,在使用对偶模型检索出的首条结果中,70% 的错误结果其实是漏标的正确答案。这种情况下,构造训练数据中的强负例时很容易引入假负例(false negative),给模型训练带来负面影响。

  1. 相对于开放域全集,人工标注训练数据的规模小、成本大

尽管目前已有较多大规模的问答数据集,但是相较于开放域的用户问题来说,仍然是冰山一角。有限的标注数据集无法覆盖到全面的领域和类型,导致模型泛化性差。想要增大标注数据的规模和质量,需要很高的人工成本。

模型介绍

RocketQA模型结构与传统的对偶式模型一致,由两个具有相同模型结构的独立模型(相同的模型初始化,但权重在训练时并不保持一致)分别对问题和段落进行编码,通过这种方式可以将段落编码和问题编码分离开,使用预先计算的方式得到所有段落的编码。两个模型可以使用相同的预训练模型如bert来进行初始化,然后使用第一个token的表达作为编码输出。

训练方式类似dssm,通过最大以下损失函数来使question-positive passage的相似度高于question-negative passage
L ( q i , { p i , j − } j = 1 m ) = − l o g e s i m ( q i , p i + ) e s i m ( q i , p i + ) + ∑ j = 1 m e s i m ( q i , p i , j − ) s i m ( q , p ) = E q ( q ) T ⋅ E p ( p ) L(q_i,\{p^-_{i,j}\}_{j=1}^m)=-log\frac{e^{sim(q_i,p_i^+)}}{e^{sim(q_i,p_i^+)}+\sum^m_{j=1}e^{sim(q_i,p^-_{i,j})}}\\ sim(q,p)=E_q(q)^T\cdot E_p(p) L(qi,{pi,j}j=1m)=logesim(qi,pi+)+j=1mesim(qi,pi,j)esim(qi,pi+)sim(q,p)=Eq(q)TEp(p)
在推断时通常需要面对大量的候选段落,因此常采用的加速方法是Asymmetric lsh (alsh) for sublinear time maximum inner product search,在本文中作者使用的是faiss。基本思想就是离线构造稠密向量的索引,在查询时,给定一个问题编码,在次线性时间内找到top-k个最相似的段落。

优化方法

在前文中,提到了对偶式检索模型遇到的问题,对此百度提出如下解决方案。

  1. 跨批次负采样(cross-batch negatives)

    采用传统的批次内负采样方法训练时,每个问题的候选段落个数与批次大小相同。为了进一步增加训练过程中候选段落的数量,百度提出了跨批次负采样方法(如图 3 所示)。该方法能够在使用多 GPU 并行训练时,将其它 GPU 批次内的全部段落作为当前问题的负样本。这样可以直接复用各个 GPU 上已经计算好的段落表示,不额外增加计算量;同时基于飞桨分布式训练扩展工具包 FleetX 的 all-gather 算子实现,只需要使用很少的通信量和内存开销,就达到了增加每个问题候选段落的目的。随着 GPU 个数的增加,每个问题的候选段落个数线性增加,训练场景中的任务难度也更加接近真实场景。百度在 MSMARCO 数据集上进行了实验,在使用跨批次负采样后,随着训练时候选段落数量增加,模型的效果稳步提升(如图 4 所示)。

    RocketQA学习_第1张图片

    图 3 批次内负采样(上)和跨批次负采样(下)的对比

    RocketQA学习_第2张图片

    图 4 MSMARCO 数据集中,训练阶段候选段落的个数对模型效果的影响

  2. 去噪的强负例采样(denoised hard negative sampling)

在对偶模型的训练中,适当增加训练数据中的强负例的难度,有助于提升模型效果。一般的做法是,从一个排序的候选段落中进行采样,越靠前的负例对模型来说难度越大。但是由于难以避免的漏标注情况,直接采样很大概率会引入假负例。为了解决这一问题,百度使用交互模型(cross-encoder)的打分作为监督信息进行去噪。在选择强负例时,避开交互模型给出高置信度的样例。相较于对偶模型,交互模型具有结构上的优势,能够编码更多的交互信息,从而给出可靠的监督信号,帮助对偶模型选取更可靠的强负例。如表 1 的第三行和第四行所示,去噪的强负例采样可以显著提升模型效果。

  1. 数据增强(data augmentation)

交互模型可以过滤强负例中的噪声,也可以用来选取未标注的正确答案。因此,当引入大量无标注的问题时,便可以利用交互模型以极低的成本得到大量弱监督数据,进一步增强对偶模型的能力。在 MSMARCO 数据集的实验中,百度引入了 Yahoo!Answers 和 ORCAS 数据集中的 150 万未标注问题,用交互模型在对偶模型检索出的候选段落上进行打分,并根据置信度选取正负样本。如下表的第四行和第五行所示,通过这种方式,对偶模型的效果得到进一步提升。
RocketQA学习_第3张图片

训练流程

C表示收集到段落集合, Q L Q_L QL表示与在C中的段落有对应标签的问题集合, Q U Q_U QU表示没有对应段落标签的问题集合。 D L D_L DL表示包含C和 Q L Q_L QL的数据集, D U D_U DU表示包含C和 Q U Q_U QU的数据集。

  1. 使用跨批次负采样方法训练一个对偶检索模型 M D ( 0 ) M_D^{(0)} MD(0)
  2. 使用 D L D_L DL训练一个交互模型 M C M_C MC。此处的负采样使用的是对偶模型 M D ( 0 ) M_D^{(0)} MD(0)对每个问题q从C中找出的top-k相似段落(排除正例段落)。
  3. 训练一个对偶检索模型 M D ( 1 ) M_D^{(1)} MD(1)通过引入去噪的强负例采样。此处的负采样使用的是对偶模型 M D ( 0 ) M_D^{(0)} MD(0)对每个问题q从C中找出的top-k相似段落(排除正例段落),然后将相似段落使用模型 M C M_C MC进行预测,如果有高概率为正例,则将其移除。(即降噪)
  4. 通过模型 M D ( 1 ) M_D^{(1)} MD(1)来为 Q U Q_U QU中的问题找到 C C C中相似的top-k段落,使用模型 M C M_C MC来为top-k的段落打标,然后使用数据 D L D_L DL和数据增强后的 D U D_U DU来训练对偶检索模型 M D ( 2 ) M_D^{(2)} MD(2).

注意:跨批次采样策略使用在每一个对偶检索模型的训练中。

注意:交互模型在step 3和step 4中的使用目的不同,step 3是为了降噪,step 4是为了做数据增强。

参考

端到端问答新突破:百度提出RocketQA,登顶MSMARCO榜首

你可能感兴趣的:(文本检索,自然语言处理)