MS MARCO Passage Ranking Leaderboard —— RocketQAv2

    本文对 RocketQA 的第二代版本 RocketQAv2 进行解读,原文地址请点击此处。

1. 背景介绍

    按照retrieve-then-rerank的方式,段落检索中的密集检索器和段落重排序器共同对最终性能做出贡献。尽管这两个模块在推理阶段作为管道工作,但发现联合训练它们是有用的。例如,具有双编码器的检索器可以通过从具有更强大的交叉编码器架构的重排序器中提取知识来改进,同时重排序器也可以通过检索器生成的训练实例进行改进。因此,越来越多的人关注retrier和re-ranker的联合训练,以实现相互改进。

    然而,这两个模块通常以不同的方式进行优化,因此无法轻易实现联合学习。检索器通常通过对 in-batch negatives 进行采样来训练,以最大化相关段落的概率并最小化采样负样本的概率,其中模型是通过考虑正负样本的整个列表来学习的(称为 listwise 方法)。而重排器通常是使用 pointwise 或者 pairwise 的方式进行训练。

    本篇论文提出了一种统一检索器和重排器的训练方法—— Dynamic Listwise Distillation,统一使用 listwise 联合训练它们,并且使用数据增强方法为 listwise 训练提供多样化和高质量的训练实例。

2. 任务定义

    给定一个查询 q q q,密集段落检索的目的是从 M M M 个文本段落的大集合中检索 k k k 个最相关的段落。先前的工作广泛采用双编码器(Dual Encoder,DE)架构,其中两个独立的密集编码器 E P ( ⋅ ) E_{P}(·) EP() E Q ( ⋅ ) E_{Q}(·) EQ() 用于将段落和查询分别映射到 d 维实值向量(也称为embedding),查询 q q q 和段落 p p p 之间的相似性使用点积定义:
s d e ( q , p ) = E Q ( q ) ⊤ ⋅ E P ( p ) s_{\mathrm{de}}(q, p)=E_{Q}(q)^{\top} \cdot E_{P}(p) sde(q,p)=EQ(q)EP(p)

    给定一个由检索器检索到的候选段落列表,段落重新排序的目的是通过重排器进一步改进检索结果,重排器估计一个相关性分数 s ( q , p ) s(q, p) s(q,p),衡量查询 q q q 和一个候选段落 p p p 的相关性水平。在重排器的实现中,基于 PLM 的交叉编码器(Cross Encoder,CE)通常可以实现卓越的性能。它可以更好地捕捉段落和查询之间的语义交互,但比双编码器需要更多的计算量。

3. Dynamic Listwise Distillation

    给定查询集 Q \mathcal{Q} Q 中的查询 q q q 和与之相应的候选段落列表 P q = { p q , i } 1 ≤ i ≤ m \mathcal{P}_{q}=\left\{p_{q, i}\right\}_{1 \leq i \leq m} Pq={pq,i}1im ,我们可以分别从 DE 和 CE 中获得查询 q q q P q \mathcal{P}_{q} Pq 中的段落的相关性分数 S d e ( q ) = { s de ⁡ ( q , p ) } p ∈ P q S_{\mathrm{de}}(q)=\left\{s_{\operatorname{de}}(q, p)\right\}_{p \in \mathcal{P}_{q}} Sde(q)={sde(q,p)}pPq S c e ( q ) = { s c e ( q , p ) } p ∈ P q S_{\mathrm{ce}}(q)=\left\{s_{\mathrm{ce}}(q, p)\right\}_{p \in \mathcal{P}_{q}} Sce(q)={sce(q,p)}pPq。然后,我们以 listwise 方式对它们进行归一化,以获得候选段落的相关性分布:
s ~ d e ( q , p ) = e s d e ( q , p ) ∑ p ′ ∈ P q e s d e ( q , p ′ ) s ~ c e ( q , p ) = e s c e ( q , p ) ∑ p ′ ∈ P q e s c e ( q , p ′ ) \begin{array}{l} \tilde{\boldsymbol{s}}_{\mathrm{de}}(q, p)=\frac{e^{s_{\mathrm{de}}(q, p)}}{\sum_{p^{\prime} \in \mathcal{P}_{q}} e^{s_{\mathrm{de}}\left(q, p^{\prime}\right)}} \\\\ \tilde{\boldsymbol{s}}_{\mathrm{ce}}(q, p)=\frac{e^{s_{\mathrm{ce}}(q, p)}}{\sum_{p^{\prime} \in \mathcal{P}_{q}} e^{s_{\mathrm{ce}}\left(q, p^{\prime}\right)}} \end{array} s~de(q,p)=pPqesde(q,p)esde(q,p)s~ce(q,p)=pPqesce(q,p)esce(q,p)

    主要思想是自适应地减小 retriever 和 re-ranker 两个分布之间的差异,论文通过最小化两个分布的 KL 散度实现这一目标:

L K L = ∑ q ∈ Q , p ∈ P q s ~ d e ( q , p ) ⋅ log ⁡ s ~ d e ( q , p ) s ~ c e ( q , p ) \mathcal{L}_{\mathrm{KL}}=\sum_{q \in \mathcal{Q}, p \in \mathcal{P}_{q}} \tilde{s}_{\mathrm{de}}(q, p) \cdot \log \frac{\tilde{s}_{\mathrm{de}}(q, p)}{\widetilde{s}_{\mathrm{ce}}(q, p)} LKL=qQ,pPqs~de(q,p)logs ce(q,p)s~de(q,p)

    此外,对于CE,还使用了交叉熵损失函数,旨在最大化列表中正样本的概率:

L sup  = − 1 N ∑ q ∈ Q , p + log ⁡ e s c e ( q , p + ) e s c e ( q , p + ) + ∑ p − e s c e ( q , p − ) \mathcal{L}_{\text {sup }}=-\frac{1}{N} \sum_{q \in \mathcal{Q}, p^{+}} \log \frac{e^{s_{\mathrm{ce}}\left(q, p^{+}\right)}}{e^{s_{\mathrm{ce}}\left(q, p^{+}\right)}+\sum_{p^{-}} e^{s_{\mathrm{ce}}\left(q, p^{-}\right)}} Lsup =N1qQ,p+logesce(q,p+)+pesce(q,p)esce(q,p+)

    其中 N N N 是训练实例的数量, p + p^{+} p+ p − p^{-} p 分别表示 P q \mathcal{P}_{q} Pq 中的正样本和负样本。将 KL 散度损失和交叉熵损失结合起来,得到最终的损失函数:

L final  = L K L + L sup  . \mathcal{L}_{\text {final }}=\mathcal{L}_{\mathrm{KL}}+\mathcal{L}_{\text {sup }} . Lfinal =LKL+Lsup .

MS MARCO Passage Ranking Leaderboard —— RocketQAv2_第1张图片
    上图显示了 Dynamic Listwise Distillation 的流程,很明显可以发现,其实就是将检索器的 listwise 训练方法转移给了重排器,然后再用重排器生成的分布去训练检索器,不同于 rocketQA 使用重排器生成的硬伪标签去训练检索器,rocketQAv2 利用软标签(即估计的相关性分布)进行相关性蒸馏。

4. Hybrid Data Augmentation

    为了执行 Dynamic Listwise Distillation,我们需要为查询 q q q 生成候选段落列表 P q \mathcal{P}_{q} Pq。回顾一下 rocketQA,它的 P q \mathcal{P}_{q} Pq 由一条正样本,同一 batch 内其他查询的正样本作为负样本(in-batch negatives sample) 和 CE 去噪后的硬负样本构成。rocketQAv2 不再使用 in-batch negatives sample,只通过结合随机采样和去噪采样硬负样本的方法来构建 P q \mathcal{P}_{q} Pq

MS MARCO Passage Ranking Leaderboard —— RocketQAv2_第2张图片
    如图所示,首先,利用 RocketQA DE 从语料库中检索前 n 个段落。对于随机采样,直接从检索到的段落中随机抽取未去噪的硬负样本。对于去噪采样,则先利用 RocketQA CE 来去除置信度分数较低的预测负样本,然后再随机采样。同时,还将一些 RocketQA CE 预测置信度分数较高的预测正样本标记为正例,以增加正例的数量以及减轻假阴性的影响。

5. 实验结果分析

MS MARCO Passage Ranking Leaderboard —— RocketQAv2_第3张图片
    不同段落检索方法的结果如表 2 所示,第一部分的 BM25 是最传统的稀疏检索算法,第二部分是使用深度学习模型加强后的稀疏检索算法,第三部分是密集检索算法。可以观察到:(1)RocketQAv2 检索器和 PAIR 大大优于其他基线。 PAIR 是 RocketQAv2 的同时代工作(同一批作者),它通过对域外数据进行预训练来获得改进,而 RocketQAv2 没有额外的数据。(2)RocketQAv2 优于 DPR-E(使用ERNIE 实现的 DPR),表明 PLM 不是获取提升的因素。(3)在稀疏检索器中,我们发现 COIL 优于其他方法,这似乎是一个强大的稀疏基线,可以在两个数据集上提供显著的性能。我们还观察到,稀疏检索器的总体表现比密集检索器差,这表明密集检索方法的有效性。

MS MARCO Passage Ranking Leaderboard —— RocketQAv2_第4张图片
    不同段落重排序方法的结果如表 3 所示,图中 RocketQA 的重排器使用 pairwise 的方式进行训练,RocketQAv2 的重排器使用 listwise 的方式进行训练。可以看到,在使用相同的检索器 RocketQA retriever 时,RocketQAv2 re-ranker 的MRR@10要比 RocketQA re-ranker 高接近一个点,证明了 listwise 训练重排器确实有效。此外,将检索器更改为 RocketQAv2 retriever 时,RocketQAv2 re-ranker 的MRR@10只提升了0.1,尽管 RocketQAv2 retriever 的效果比 RocketQA retriever 好了接近2个点,这说明MRR的提升很大程度上是由 re-ranker 贡献的。

MS MARCO Passage Ranking Leaderboard —— RocketQAv2_第5张图片

    图 4 展示了每个查询所用的训练实例数量对检索器和重排器的影响。对于每个查询,我们采样一个正实例,而实例列表 P q \mathcal{P}_{q} Pq 中的其余实例是硬负实例。因此,实例数的影响应该等同于硬负实例数的影响。可以看到,随着硬负实例的增加,检索器的重排器的性能也随之提高。此外,论文也尝试将 in-batch negative sample 加入训练,但没有看到明显提升。
    总的来说,RocketQAv2 的训练方式比 RocketQA 优雅了许多,效果也更好,但是由于抛弃了 in-batch negative sample,每个 query 都要带额外的负样本(论文中报告的最好结果是使用了384个负样本),在保证batch有足够大的情况下,就不可避免地带来了显存的增加(论文中使用了 32个 V100 进行训练)。综上所述,RocketQAv2 是土豪玩家的玩法,平民还是老老实实用 RocketQA 吧。
在这里插入图片描述

你可能感兴趣的:(QA,自然语言处理,自然语言处理)