RAG小结

RAG

RAG做知识问答,我们所使用的语料、索引建立方式、训练方式,大致的模型参数,以及满足以上我们对GPU运算资源的需求。

  • 需求总结:
    • 以下表格列举我们实验配置和论文实验配置;
    • 如果要使用base甚至large版本的预训练模型和batch128,最好是单服务器节点有3块16G+的gpu;没有的情况下,可以利用多节点多GPU并行,需要配置;但单块gpu能跑的batch还是很有限;
配置 Retriever Generator RAG
模型 DPR
(Albert-small*2)
mT5-small DPR+mT5
(fix-ctx)
参数 5M*2 80M 85M
训练语料 问诊(11W)+母婴(2K)
测试1W
问诊(11W)+母婴(2K)
测试1W
问诊(11W)+母婴(2K)
测试1W
训练方式 B-24/S-256/40E B-16/S-256/30E B-16/S-256/30E
索引方式 Faiss(12W/1hour)
实际占用GPU 3gpus*9G/gpu
(office0-1080ti)
2gpus*10G/gpu
(office2-2080ti)
2gpus*11G/gpu
(office2-2080ti)
以下为论文使用配置
模型 DPR
(Bert-base-uncased*2)
Bart-large DPR+Bart
(fix-ctx)
参数 120M*2 400M 520M
训练语料 Eng-QA(16W)+Wiki(2KW ) Eng-S2S预训练 QA(16W)
训练方式 B-128/S-100/40E B-8000 B-128/S-100/40E(约)
索引方式 Faiss(21M/8.5hour)
实际占用GPU 8gpus*32G/gpu 不少于DPR(约) 不少于DPR(约)

注1:fix-ctx指固定context encoder,不做更新
注2:M-百万,W-万,K-千;h指hour,B指Batch,S指Seq length,E指Epoch;Eng指English;S2S指Seq2Seq;
注3:我们的其他可用语料近600W–百度百科(146W) & dureader(22W) & web问答(400W) & cmrc2018(1.3W);
注4:“约” 指的是论文没有明确指出配置具体情况,表格给出的相关数据为综合推测;
注5:DPR论文提到21M passages建立Faiss索引只需8.5h,使用E5-2698 [email protected] CPU和512G的内存;使用of2 CPU和96G内存实测faiss建索引,速度约12W/1h,300W/85h(非线性);查看实测用的Faiss索引代码和DPR论文索引代码实现方式,基本一致;估计是硬件差别;
注6:数据问题–DPR论文训练数据answer有对应的源context/passage,我们的数据没有(除了2K母婴QA),目前是利用answer同时做context,问题应该不大;

  • RAG相关

    • 流程如图RAG小结_第1张图片
    • 基本过程:RAG由Retriever和Generator组成,输入编码后(对QA来说即question)传入Retriever,Retriever检索与question相关的多个内容(称为contexts),并将contexts与question结合后,作为新的输入发送给Generator,Generator接受输入并生成结果(对QA即生成答案)。
  • 组件细节

    • Retriever
      • 来自另一篇论文DPR。DPR使用Bi-Encoder,即双编码器,分别编码question和context。原始paper使用的两个编码器为 Bert-base-uncased;
      • Retriever训练使用In-batch negative training,大致是 batch内每个样本(即question)对应的positive context均为其他样本的negative;每个样本另外利用BM25寻找最相似的context作为hard negative,以此增强模型学习能力;
      • Inference阶段,Retriever将输入的question编码为q_vector,再利用q_vector和事先做好索引的passages(检索语料库,即context)进行相似度计算,获得相似度top k的contexts,与question拼接,结果作为新的输入传递给generator,由generator生成答案;
    • Generator
      • RAG论文使用的generator为预训练的Bart模型,Bart是一个预训练的seq2seq(S2S)模型,将很多任务当做S2S模式来训练;MBart/MT5为基于Bart/T5的多语言预训练模型;
      • generator经过预训练,已存有一定的知识,再结合检索的contexts,可以更好的生成答案;
  • 训练过程

    • 固定retriever的context encoder,训练question encoder及generator;

你可能感兴趣的:(技术问题,#,算法,#,深度学习,人工智能,深度学习)