先上论文链接
RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering
目前的问答系统一般都是retrieval-reader的架构,即先检索相关文档,再进行阅读。检索器作为召回组件,对于阅读器的表现至关重要,但目前的检索器都面临着一个问题,训练精度与推理精度差距大。此外,训练数据集中存在着很多无标签的样本,也会影响模型的精度。
由于交叉编码器通过深度互动来探究问题与段落的相似性,因此它在训练完成之后,在无标签的样本上区分正负例的能力更强,更加具有鲁棒性,因此RocketQA在区分无标签数据的正负性时使用的是这个架构。
由于训练与推理中的负例样本差距大导致精度可能出现差异,RocketQA提出了跨批次负采样在训练中增加负例样本。
原先的批次内采样是在单个GPU的一个批次内,对于 B B B个问题段落对,每个问题都有与之对应的最相关的段落,那么剩下的 B − 1 B-1 B−1个段落作为问题的负例段落。在进行多GPU并行训练时,这个方法就无法充分的使用别的GPU的样本数据了。跨批次负采样的方法是,将每个GPU中的段落在所有GPU中共享,这样如果有 A A A个GPU的话,那么每个问题就能够拥有 × − 1 \times−1 A×B−1个负例样本。
上面的方法可以有效增加训练中的负例样本数量,但是大部分的负例样本对于模型来说很好判断。强负样本对于模型是非常重要的,它能够增加模型的鲁棒性,为了获取强负样本,一般的做法是在返回的top-K个段落中,根据标签挑选出负例段落,这样的段落,模型将其视为与答案最相关的K个段落之一,因此我们可以把它看作强负例样本,但是,由于漏标的存在,这样的段落也有可能是没有被标记的正例样本。RocketQA的做法是利用训练好的交叉编码器,对于top-K个段落,首先去掉有标签的正例样本,然后将剩下的段落交给交叉编码器进行判断,判断为负例的段落,模型才将其视为强负例样本。
这个方法旨在缓解训练数据有限的情况,交叉编码器是使用实验的数据集进行训练的,为了防止模型在检索中的作弊行为,文章引入了大量的无标签的数据,首先交给交叉编码器进行标记,然后将其扩充到原始数据集中,为了保证标记的质量,这里会设置一个阈值,只有高阈值的正例和负例才能够被加入到原始数据集中。这里的数据增强方式可以被视为是一种知识蒸馏,交叉编码器是教师模型,接下来被训练的对偶编码器是学生模型。
定义
C表示段落集。 Q L Q_L QL是在 C C C中有标签的段落对应的问题, _ QU是在 C中有标签的段落对应的问题, _ DL是由 C和 _ QL组成的数据集, _ DU是由 C和 _ QU组成的数据集。
步骤1
通过跨批次负采样方法在 _ DL上训练一个对偶编码器 ( 0 ) _^{(0)} MD(0)。
步骤2
在 _ DL上训练一个交叉编码器 _ MC。正例样本使用原始数据集中已经标定的,负例样本使用从 ( 0 ) _^{(0)} MD(0)中返回的top-K个文档中除去正例样本之后,剩下的样本进行随机采样。
步骤3
通过去噪的强负例方法在 _ DL上训练一个双编码器 ( 1 ) _^{(1)} MD(1)。强负例样本由 ( 0 ) _^{(0)} MD(0)返回,并且紧接着送入 _ MC,再从 _ MC的输出中挑选置信度高的负例样本作为去噪后的负例样本。
步骤4
对于每个问题,先通过 ( 1 ) _^{(1)} MD(1)返回top-K个段落,再通过 _ MC来对段落进行标记,最后再利用手动标记的数据与标记的数据一起训练一个对偶编码器 ( 2 ) _^{(2)} MD(2)。
跨批次负采样策略在这四个步骤中都会使用到。