kNN-LM
->REALM
->DPR
->RAG
->FiD
->COG
->GenRead
->REPLUG
->Adaptive retrieval
kNN-LM
, an approach that extends a pre-trained LM by linearly interpolating its next-word distribution with a k-nearest neighbors (kNN) model.
Datastore: $(\mathcal{K}, \mathcal{V})$, the set of all key-value pairs constructed from all the training examples in $\mathcal{D}$
Inference: interpolate the nearest-neighbor distribution $p_{kNN}$ with the model distribution $p_{LM}$ using a tuned parameter $\lambda$ to produce the final kNN-LM distribution (given input context $x$)
$p_{LM}(y|x)$: given the input context $x$, the model generates the output distribution over next words $p_{LM}(y|x)$
$p_{kNN}(y|x)$: a distribution over the k nearest neighbors
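Below is a minimal sketch of that interpolation step, assuming the k nearest neighbors and their distances have already been retrieved from the datastore (tensor names are illustrative, not from the released code):

```python
import torch
import torch.nn.functional as F

def knn_lm_distribution(p_lm, knn_distances, knn_values, vocab_size, lam=0.25):
    """Interpolate the LM distribution with a kNN distribution.

    p_lm:          (vocab_size,) next-token distribution from the base LM.
    knn_distances: (k,) L2 distances of the retrieved keys to the query context.
    knn_values:    (k,) token ids stored as values for the retrieved keys.
    lam:           interpolation weight lambda (tuned on validation data).
    """
    # Neighbors are weighted by a softmax over negative distances,
    # then aggregated per vocabulary item.
    weights = F.softmax(-knn_distances, dim=0)              # (k,)
    p_knn = torch.zeros(vocab_size)
    p_knn.scatter_add_(0, knn_values, weights)              # sum weights per token id
    # Final kNN-LM distribution: lambda * p_kNN + (1 - lambda) * p_LM
    return lam * p_knn + (1 - lam) * p_lm
```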
Performance on WIKITEXT-103
Can retrieving nearest neighbors from data be a substitute for training on it?
Training on WIKI-100M and retrieving from WIKI-3B is better than training on WIKI-3B
This suggests learning representations on a smaller dataset and augmenting them with a kNN-LM over a large corpus.
How does the amount of data used for kNN retrieval affect performance?
Domain adaptation: training on WIKI-3B and performing inference on BOOKS
Key function
Number of neighbors per query (Figure 4) and interpolation parameter (Figure 5)
Examples where kNN-LM is most helpful typically contain rare patterns.
Drawbacks of pre-trained language models
Limitations of prior work: retrieve relevant documents and extract an answer from the docs.
This paper extends retrieval to language model pre-training and proposes REALM, a retrieve-then-predict approach.
Methods compared with:
For both pre-training and fine-tuning, REALM
takes some input x and learns a distribution p(y | x) over possible outputs y.
pre-training: masked language modeling
fine-tuning: Open-QA
Two stages:
pretraining: use MLM loss
Open-QA fine-tuning: assume that the answer $y$ can be found as a contiguous sequence of tokens in some document $z$
$\mathrm{BERT}_{\mathrm{START}(s)}$ and $\mathrm{BERT}_{\mathrm{END}(s)}$ denote the Transformer output vectors corresponding to the start and end tokens of span $s$, respectively
The score of the correct span is made large, but don't we also need to ensure that the scores of incorrect spans are small?
do not update $\mathrm{Embed}_{doc}$ for simplicity
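A rough sketch of how the fine-tuning objective combines the retrieval distribution $p(z|x)$ with the span-based $p(y|z,x)$ described above; the helper names and tensor shapes are assumptions for illustration, not REALM's actual code:

```python
import torch
import torch.nn.functional as F

def realm_answer_log_prob(query_emb, doc_embs, start_vecs, end_vecs, span_mask, mlp):
    """log p(y | x) = log sum_z p(z | x) * p(y | z, x) over the top-k retrieved docs.

    query_emb: (d,)            Embed_input(x)
    doc_embs:  (k, d)          Embed_doc(z) for the k retrieved documents
    start_vecs, end_vecs: (k, n_spans, h)  BERT_START(s) / BERT_END(s) for candidate spans
    span_mask: (k, n_spans)    1 if the span matches the gold answer y in document z
    mlp:       scores a span from the concatenated start/end vectors
    """
    # Retrieval distribution p(z | x): softmax over inner products.
    log_p_z = F.log_softmax(doc_embs @ query_emb, dim=0)                      # (k,)

    # Span scores f(s) = MLP([h_START(s); h_END(s)]).
    span_scores = mlp(torch.cat([start_vecs, end_vecs], dim=-1)).squeeze(-1)  # (k, n_spans)
    log_norm = torch.logsumexp(span_scores, dim=-1)                           # (k,)
    # p(y | z, x): sum over spans that match the answer string.
    matched = span_scores.masked_fill(span_mask == 0, float("-inf"))
    log_p_y_given_z = torch.logsumexp(matched, dim=-1) - log_norm             # (k,)

    return torch.logsumexp(log_p_z + log_p_y_given_z, dim=0)
```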
Pretraining: 8 candidate documents; two choices of corpus: (1) Wikipedia (2) CC-News
Finetuning: consider top-5 candidates
Can we train a better dense embedding model using only pairs of questions and passages (or answers), without additional pretraining?
Propose DPR, a two-stage framework:
Encoders: two independent BERT encoders (one for questions, one for passages)
Training:
goal: create a vector space such that relevant pairs of questions and passages have smaller distance than irrelevant pairs (a minimal in-batch-negative training sketch follows the dataset list below)
source documents: Wikipedia dump from Dec. 20, 2018 (split into 100-word passages; each passage prepended with the article title)
QA datasets: Natural Questions
; TriviaQA
; WebQuestions
; CuratedTREC
; SQuAD v1.1
NQ, TriviaQA, SQuAD
TREC, WQ
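As referenced above, here is a minimal sketch of the contrastive training objective: the negative log-likelihood of the gold passage, shown here with in-batch negatives only (variable names are illustrative):

```python
import torch
import torch.nn.functional as F

def dpr_in_batch_loss(q_embs, p_embs):
    """Negative log-likelihood of the gold passage with in-batch negatives.

    q_embs: (B, d) question [CLS] embeddings from the question encoder.
    p_embs: (B, d) gold passage [CLS] embeddings from the passage encoder;
            passage j serves as a negative for every question i != j.
    """
    sim = q_embs @ p_embs.T                    # (B, B) dot-product similarities
    targets = torch.arange(q_embs.size(0))     # gold passage sits on the diagonal
    return F.cross_entropy(sim, targets)
```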
Retrieval
End-to-end QA
Besides the retriever, our QA system consists of a neural reader that extracts an answer span from the passages
BERT is used to predict the start_token and the end_token of the answer span
1.预训练模型存储知识的能力很强,但访问和精准操控知识的能力还受限,所以在knowledge-intensive任务上不如task-specific架构。
2.parametric memory with non-parametric (i.e., retrieval-based) memories结合可以解决一些问题
3.REALM
和 ORQA
利用了这种形式(基于masked language model),但是只探索了 open-domain extractive question answering
因此,本文将这种方式扩展到NLP的主力seq2seq models上
RAG-Sequence: uses the same retrieved document to generate the complete sequence.
RAG-Token: uses a different latent document for each target token.
We use a pre-trained bi-encoder from DPR to initialize our retriever and to build the document index
use BART-large and simply concatenate the input $x$ and the retrieved content $z$
jointly train the retriever and generator components without any direct supervision on what document should be retrieved.
RAG-Token: decode with standard beam search, since the marginal probability of each next token is available at every step.
RAG-Sequence: run beam search for each retrieved document, producing a candidate output $y$ per document and forming a set $Y$. A candidate $y$ generated from one document may not appear in the beams of other documents, so its probability is computed under every document and the final score is $\sum_{z\in \text{top-}k}p(z|x)\,p(y|x,z)$. This is called Thorough Decoding.
Fast Decoding: skip the extra forward passes and approximate $p(y|x,z)\approx 0$ whenever $y$ was not generated during beam search from $z$.
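A small sketch of the two marginalizations, assuming the retrieval scores and generator probabilities are already computed (function and variable names are placeholders):

```python
import torch

def rag_sequence_log_prob(doc_log_scores, seq_log_probs):
    """Thorough-decoding score for one candidate output y.

    doc_log_scores: (k,) log p(z | x) for the top-k retrieved documents.
    seq_log_probs:  (k,) log p(y | x, z), i.e. y re-scored under every document
                    (extra forward passes for documents whose beam did not produce y).
    Returns log p(y | x) = log sum_z p(z | x) * p(y | x, z).
    """
    return torch.logsumexp(doc_log_scores + seq_log_probs, dim=0)

def rag_token_next_token_dist(doc_probs, token_dists):
    """RAG-Token: marginalize per generation step.

    doc_probs:   (k,)   p(z | x)
    token_dists: (k, V) p(y_t | x, z, y_<t) from the generator for each document.
    Returns the (V,) mixture distribution consumed by standard beam search.
    """
    return (doc_probs.unsqueeze(-1) * token_dists).sum(dim=0)
```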
RAG is evaluated on four knowledge-intensive tasks.
open-domain QA
Abstractive Question Answering (MSMARCO)
Jeopardy Question Generation (Jeopardy)
Fact Verification (FVR3, FVR2)
Drawbacks of previous methods (e.g., DPR and REALM):
Propose retrieval + generation.
two steps:
Reformulate text generation by copying text segments from existing text collections
Possible improvement: learn the phrase table dynamically, supporting adding, deleting, modifying, and looking up its entries, or converting fixed phrases into dynamic phrases.
At each time step, a suitable phrase is selected and appended to the current prefix accordingly
For a document $D^i$, a phrase $k = D^i_{s:e}$ of length $e-s+1$ can be extracted, where $s$ and $e$ mark the start and end positions of the phrase in the document, respectively.
Denote all the phrases in the source text collection as $\mathcal{P}$; the phrase table is $\{(k, p_k)\mid k \in \mathcal{P}\}$.
To support scenarios where no suitable phrase is available, we also add the context-independent token embeddings $\{(w, v_w)\mid w \in V\}$ used in standard LMs to the phrase table.
The model consists of three major components:
a prefix encoder that maps prefixes to fixed-sized representations
a context-dependent phrase encoder that computes the vector representations of the phrases in the source text collection
For a document $D = D_1, \ldots, D_m$ of length $m$:
first apply a deep bidirectional Transformer (BERT-base-cased) to obtain contextualized token representations $\mathbf{D} \in \mathbb{R}^{m \times d_t}$
apply two MLPs, $\mathrm{MLP}_{\mathrm{start}}$ and $\mathrm{MLP}_{\mathrm{end}}$, to convert $\mathbf{D}$ into start and end token representations, respectively:
for each phrase $D_{s:e}$, use the concatenation of the corresponding start and end vectors as the phrase representation (see the sketch after the component list)
a set of context-independent token embeddings similar to the one used in standard neural language models
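As referenced above, a sketch of the context-dependent phrase encoder; the MLP heads and dimensions are illustrative assumptions, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class PhraseEncoder(nn.Module):
    """Turn contextualized token vectors into phrase representations."""

    def __init__(self, d_token=768, d_phrase=384):
        super().__init__()
        self.mlp_start = nn.Sequential(nn.Linear(d_token, d_phrase), nn.Tanh())
        self.mlp_end = nn.Sequential(nn.Linear(d_token, d_phrase), nn.Tanh())

    def forward(self, token_reps, spans):
        """token_reps: (m, d_token) BERT outputs for one document.
        spans: list of (s, e) start/end token positions of candidate phrases.
        Returns one vector per phrase: [MLP_start(D_s); MLP_end(D_e)].
        """
        starts = self.mlp_start(token_reps)   # (m, d_phrase)
        ends = self.mlp_end(token_reps)       # (m, d_phrase)
        return torch.stack([torch.cat([starts[s], ends[e]]) for s, e in spans])
```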
Question: since the prefix representations come from GPT-2 and the phrase representations come from BERT, why can they be matched against each other? Are the two in the same representation space?
A document $D$ has been split into $n$ phrases, $D = p_1, \ldots, p_n$
the training loss for next-phrase prediction
to retain the capability of token-level generation, we also train COG with the standard token-level autoregressive loss(next-token prediction)
The training loss is the sum of these two losses.
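A minimal sketch of how the two losses could be combined, with the phrase table flattened into one embedding matrix (all names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def cog_training_loss(prefix_reps, phrase_table, gold_phrase_ids,
                      token_logits, gold_token_ids):
    """Sum of the next-phrase prediction loss and the standard next-token loss.

    prefix_reps:     (n, d)  prefix-encoder states before each gold phrase.
    phrase_table:    (P, d)  phrase vectors plus context-independent token embeddings.
    gold_phrase_ids: (n,)    index of the gold next phrase in the table.
    token_logits:    (m, V)  ordinary LM logits for next-token prediction.
    gold_token_ids:  (m,)    gold next tokens.
    """
    phrase_logits = prefix_reps @ phrase_table.T          # dot-product matching
    loss_phrase = F.cross_entropy(phrase_logits, gold_phrase_ids)
    loss_token = F.cross_entropy(token_logits, gold_token_ids)
    return loss_phrase + loss_token
```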
Inference Speed
The inference latency of kNN-LM is much higher than that of the Transformer baseline, while COG remains comparable to the Transformer.
Case Study
COG
allows a single model to be specialized in different domains, by simply switching the source text collection.
Idea:
Levenshtein Transformer: during generation, this model can insert, delete, and revise tokens in its output (NeurIPS 2019)
ICLR 2023 (review scores: 8, 8, 8, 10)
Three drawbacks of retrieve-then-read pipeline
Propose to leverage LLMs to directly generate contextual documents for a given question, with two advantages:
generated contextual documents contain the correct answer more often than the top retrieved documents
our approach significantly outperforms directly generating answers from large language models despite not incorporating any new external information
Two steps:
first prompt an LLM to generate contextual documents with respect to a given query
read the generated documents to predict the final answer (a large model like InstructGPT for zero-shot, or a smaller model like FiD for finetuning)
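A sketch of the two-step pipeline in the zero-shot case; the prompt wording and the `llm` callable are placeholders for illustration, not the paper's exact prompts:

```python
def generate_then_read(question, llm):
    """Step 1: ask the LLM for a background document; step 2: answer from it."""
    doc_prompt = f"Generate a background document to answer the given question.\n{question}"
    document = llm(doc_prompt)          # greedy decoding in the zero-shot setting

    answer_prompt = (f"Refer to the passage below and answer the question.\n"
                     f"Passage: {document}\nQuestion: {question}")
    return llm(answer_prompt)
```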
Zero-shot setting:
use a large model (InstructGPT) to generate documents based on the given question with a greedy decoding strategy
Supervised setting:
Explore how the generated documents from large language models can benefit the supervised setting.
use FiD to peruse the generated documents under the supervised setting (finetune the reader)
Clustering-based prompts:
The in-context (question, document) pairs are question-independent, i.e., the same demonstrations are reused for every question. As a result, the documents generated for different questions may all relate to one particular aspect of the question, because the relation demonstrated in the prompt is always the same.
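A sketch of the clustering-based prompt construction implied above, assuming some text-embedding function and scikit-learn K-means (both are illustrative choices, not the paper's exact setup):

```python
import numpy as np
from sklearn.cluster import KMeans

def build_diverse_prompts(qd_pairs, embed, n_clusters=10, demos_per_prompt=5):
    """Cluster (question, document) pairs and sample in-context demos per cluster.

    qd_pairs: list of (question, document) pairs, e.g. from an initial generation round.
    embed:    function mapping a text to a vector.
    Returns one list of demonstrations per cluster, so each prompt nudges the LLM
    to generate documents covering a different aspect of the question.
    """
    vecs = np.stack([embed(q + " " + d) for q, d in qd_pairs])
    labels = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(vecs)
    prompts = []
    for c in range(n_clusters):
        members = [qd_pairs[i] for i in np.where(labels == c)[0]]
        prompts.append(members[:demos_per_prompt])
    return prompts
```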
Zero-shot
Supervised setting
InstructGPT + FiD (FiD is fine-tuned on the training split of the target datasets)
Other tasks
Case Study
REPLUG, a retrieval-augmented LM framework that treats the language model as a black box. In REPLUG, the retrieved documents are simply prepended to the original input, so there is no need to update the LM parameters as in prior work; performance can be further improved by updating the retriever.
FAISS is used to quickly find the top-k documents.
The LM output probabilities computed from each retrieved document (each document is prepended to the input separately) are combined in a weighted ensemble, with weights given by a softmax over the retrieval scores.
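A minimal sketch of this black-box ensemble; `lm_next_token_dist` is a placeholder for a call to the frozen language model:

```python
import torch
import torch.nn.functional as F

def replug_next_token_dist(x, docs, retrieval_scores, lm_next_token_dist):
    """Weighted ensemble of per-document LM distributions.

    x:                  input context (string).
    docs:               top-k retrieved documents.
    retrieval_scores:   (k,) similarity scores s(d, x) from the retriever.
    lm_next_token_dist: black-box LM call returning a (V,) distribution.
    """
    weights = F.softmax(retrieval_scores, dim=0)                          # lambda(d, x)
    dists = torch.stack([lm_next_token_dist(d + " " + x) for d in docs])  # (k, V)
    return (weights.unsqueeze(-1) * dists).sum(dim=0)
```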
REPLUG LSR can be seen as an enhanced version of REPLUG. In REPLUG the retriever may not be well matched to the language model, so REPLUG LSR uses supervision signals fed back by the language model itself to adapt the retriever.
Core idea: our approach can be seen as adjusting the probabilities of the retrieved documents to match the output-sequence probabilities (perplexities) of the language model
The larger the probability of the ground-truth sequence, the better we consider the model to perform. This part describes how to compute the retrieved-document probability distribution and the output-sequence probability distribution.
Given an input $x$, we retrieve the top-k documents with the highest scores, $\mathcal{D}' \subset \mathcal{D}$. The retrieval likelihood of a document $d$ is
$$P_R(d \mid x)=\frac{e^{s(d, x) / \gamma}}{\sum_{d \in \mathcal{D}^{\prime}} e^{s(d, x) / \gamma}}$$
where $\gamma$ is a hyperparameter controlling the temperature of the softmax.
Ideally the normalization would be over the whole corpus $\mathcal{D}$, but that is too expensive, so it is approximated over $\mathcal{D}'$.
The language model is used to evaluate how much each document improves the LM's perplexity: first compute $P_{LM}(y\mid d,x)$, the generation probability of the ground truth $y$ given the input $x$ and a document $d$. The larger this probability, the more the document improves the perplexity. Then compute the distribution:
$$Q(d \mid x, y)=\frac{e^{P_{LM}(y \mid d, x) / \beta}}{\sum_{d \in \mathcal{D}^{\prime}} e^{P_{LM}(y \mid d, x) / \beta}}$$
With these two distributions, a loss function is used to match them:
given $x$ and $y$, compute the retrieval likelihood distribution and the LM likelihood distribution, and use the KL divergence between them to optimize the dense retriever
$$\mathcal{L}=\frac{1}{|\mathcal{B}|} \sum_{x \in \mathcal{B}} KL\left(P_R(d \mid x) \,\|\, Q_{\mathrm{LM}}(d \mid x, y)\right)$$
Because the retriever parameters are updated during training, the document embeddings become stale after each update; therefore the document embeddings are recomputed every $T$ steps and the above process is repeated.
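A sketch of the LSR training loss under the definitions above; the temperatures and the precomputed scores/likelihoods are assumed to be supplied by the caller:

```python
import torch
import torch.nn.functional as F

def lsr_loss(retrieval_scores, lm_log_likelihoods, gamma=0.1, beta=1.0):
    """KL(P_R(d|x) || Q_LM(d|x,y)) for one training query.

    retrieval_scores:   (k,) s(d, x) for the top-k documents D'.
    lm_log_likelihoods: (k,) log P_LM(y | d, x) for the ground-truth continuation y.
    """
    log_p_r = F.log_softmax(retrieval_scores / gamma, dim=0)    # retrieval likelihood
    q_lm = F.softmax(lm_log_likelihoods.exp() / beta, dim=0)    # LM likelihood (LM is frozen)
    # KL(P_R || Q_LM); only the retriever receives gradients.
    return torch.sum(log_p_r.exp() * (log_p_r - torch.log(q_lm)))
```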
All training data comes from the Pile training data (a language-model benchmark containing text from diverse domains).
800K sequences of 256 tokens each are used as training queries.
For the external corpus $\mathcal{D}$, 36M documents of 128 tokens each are sampled.
Documents sampled from the Pile training data (367M documents of 128 tokens) are used as the retrieval corpus for all models.
Atlas trains both the retriever and the language model, which we consider a white-box retrieval LM setting.
Datasets: Natural Questions and TriviaQA
Settings: few-shot (use a few training examples) and full data (use all training data)
RETRO, R2-D2, and Atlas are finetuned on the training data, either in a few-shot setting or with full training data
The performance of REPLUG and REPLUG LSR keeps improving as more documents are retrieved, but a small number of documents (e.g., 10) is already enough to do well.
The performance gains from REPLUG are consistent across model sizes and apply to different models.
REPLUG is more helpful when texts contain rare entities.
It is unclear when the model relies on retrieved knowledge versus parametric knowledge.
target: understand when we should and should not rely on LMs’ parametric knowledge, and how scaling and non-parametric memories can help
Dimensions of Analysis:
Dataset:
PopQA
: randomly sample knowledge triples of 16 relationship types from Wikidata
EntityQuestions: use Wikipedia hyperlink counts as a proxy for entity frequency and sample knowledge triples from Wikidata according to that frequency distribution
run an off-the-shelf retrieval system offline to retrieve context from Wikipedia relevant to a question, and concatenate the retrieved context (top one for simplicity) with the original question
BM25
/ Contriever
we use retrieval for questions whose popularity is lower than a threshold
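A minimal sketch of this adaptive decision rule, assuming a per-question popularity score (e.g., Wikipedia page views of the subject entity) and a threshold tuned on held-out data (both illustrative):

```python
def adaptive_answer(question, popularity, lm_answer, retrieval_augmented_answer,
                    threshold=1000):
    """Use retrieval only for less-popular (long-tail) questions.

    popularity: proxy frequency of the question's subject entity (e.g., page views).
    threshold:  assumed to be tuned on a development set.
    """
    if popularity < threshold:
        return retrieval_augmented_answer(question)  # non-parametric memory helps here
    return lm_answer(question)                       # rely on parametric knowledge
```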
LMs’ memorization (RQ1) is often limited to the popular factual knowledge and even GPT-3 davinci-003
fails to answer the majority of the long-tail questions
Non-parametric memories largely improve performance on long-tail distributions across models.
Devise a simple-yet-effective retrieval-augmented LM method, Adaptive Retrieval
, which adaptively combines parametric and non-parametric memories based on popularity