今天没什么要写的碎碎念,因为话题太过私人了所以我转移到github的博客上写了。昨天看分布式训练看到了很多不错的想法,下午找论文也寻找到了不错的idea,总之这场学习没有白费力气,很好。一定要去积极的搜索查找学习资料。
不多说废话了,开始今天的学习。
这一章是“新的模型架构”。我不知道什么是“新的模型架构”。目前来讲多模态clip也用的是transformer。
回想一下第7章 模型架构,神经语言模型的核心接口是一个将token序列映射到上下文嵌入的编码器:
[ the , mouse , ate , the , cheese ] ⇒ ϕ [ ( 1 0.1 ) , ( 0 1 ) , ( 1 1 ) , ( 1 − 0.1 ) , ( 0 − 1 ) ] . [\text{the}, \text{mouse}, \text{ate}, \text{the}, \text{cheese}] \stackrel{\phi}{\Rightarrow} \left[\binom{1}{0.1}, \binom{0}{1}, \binom{1}{1}, \binom{1}{-0.1}, \binom{0}{-1} \right].\ [the,mouse,ate,the,cheese]⇒ϕ[(0.11),(10),(11),(−0.11),(−10)].
以GPT-3为例,它是一个通过堆叠96层Transformer block,映射token序列 x 1 : L x_{1:L} x1:L的神经语言模型:
GPT-3 ( x 1 : L ) = TransformerBlock 96 ( EmbedTokenWithPosition ( x 1 : L ) ) , \text{GPT-3}(x_{1:L}) = \text{TransformerBlock}^{96}(\text{EmbedTokenWithPosition}(x_{1:L})), GPT-3(x1:L)=TransformerBlock96(EmbedTokenWithPosition(x1:L)),
其中,每层Transformer block使用
TransformerBlock ( x 1 : L ) = AddNorm ( FeedForward , AddNorm ( SelfAttention , x 1 : L ) ) . \text{TransformerBlock}(x_{1:L}) = \text{AddNorm}(\text{FeedForward}, \text{AddNorm}(\text{SelfAttention}, x_{1:L})). TransformerBlock(x1:L)=AddNorm(FeedForward,AddNorm(SelfAttention,x1:L)).
先验知识:
现状:
GPU1 [ layer1 , layer2 ] GPU2 [ layer3 , layer4 ] GPU3 [ layer5 , layer6 ] . \text{GPU1}[\text{layer1}, \text{layer2}] \quad\quad\quad \text{GPU2}[\text{layer3}, \text{layer4}] \quad\quad\quad \text{GPU3}[\text{layer5}, \text{layer6}]. GPU1[layer1,layer2]GPU2[layer3,layer4]GPU3[layer5,layer6].
在本章中,我们将探讨两种不同类型的“新”模型架构,这提高了模型的规模上限。特别地,我们将讨论:
input ⇒ expert 1 expert 2 expert 3 expert 4 ⇒ output . \text{input} \quad\quad\Rightarrow\quad\quad \text{expert}_1 \quad \text{expert}_2 \quad \text{expert}_3 \quad \text{expert}_4 \quad\quad\Rightarrow\quad\quad \text{output}. input⇒expert1expert2expert3expert4⇒output.
store ∣ input ⇒ relevant data from store ⇒ output . \text{store} \quad\quad|\quad\quad \text{input} \quad\quad\Rightarrow\quad\quad \text{relevant data from store} \quad \quad\quad\Rightarrow\quad\quad \text{output}. store∣input⇒relevant data from store⇒output.
混合专家系统(Mixture of Experts, MoE)是在神经网络 (Neural Network, NN) 领域发展起来的一种集成学习(Ensemble Learning) 技术。传统的深度学习模型在训练时,对于每个输入样本,整个网络都会参与计算。随着模型越来越大,训练使用的样本数据越来越多,训练的开销越来越难以承受。
而 MoE 可以动态激活部分神经网络,从而实现在不增加计算量的前提下大幅度增加模型参数量。
MoE 技术目前是训练万亿参数量级模型的关键技术。MoE 将预测建模任务分解为若干子任务,在每个子任务上训练一个专家模型(Expert Model),开发一个门控模型(Gating Model),该模型根据要预测的输入来学习信任哪个专家,并组合预测结果。
尽管该技术最初是使用神经网络专家和门控模型来描述的,但它可以推广到使用任何类型的模型。
大规模预训练模型可以隐式的编码知识并应用于下游任务,为开放域问答、对话、摘要等任务带来了巨大的性能提升。然而不断扩大语言模型参数量以及训练数据规模也可能带来如下问题:
低效。模型规模数量级增大带来的性能增益可能越来越小。
静态。预训练模型编码隐式知识的方式难以对具有时效性的知识进行调整。
不透明。研究者很难判断模型本身掌握了什么知识,在完成任务时用到了哪些知识。在很多下游任务中,语言模型可能会产生事实幻觉(Fact hallucination)。
为了解决上述问题,研究者们提出了改进任务表现的另一种思路:基于检索的自然语言模型。
这类模型可以在外部知识库中搜索所需信息,结合外部知识以及语言模型的本身优势完成任务。
MoE 包含四个要素:
一些预测建模任务非常复杂,因此需要尽可能的把它们划分为子任务来处理。这是解决问题的一种分治方法,是许多预测建模自动化方法以及更广泛地解决问题的基础。
例如,可以根据问题的一些领域知识将输入特征空间划分为子空间。然后可以在问题的每个子空间上训练模型,实际上是特定子问题的专家。然后,模型会学习调用哪个专家来预测未来的新示例。
第一步是将预测建模问题划分为子任务。
这通常涉及使用领域知识。例如,可以将图像划分为单独的元素,例如背景、前景、对象、颜色、线条等。
MoE 采用分治的策略,将一项复杂的任务分解为几个更简单、更小的子任务,并针对不同的子任务开发个体学习者(称为专家)进行训练。
在 MoE 系统中,一个关键问题是如何找到任务的自然划分,然后从子解决方案中得出整体解决方案。
接下来,为每个子任务设计一个专家。
MoE 方法最初是在人工神经网络领域开发和探索的,因此传统上,专家本身是用于预测回归情况下的数值或分类情况下的类别标签的神经网络模型。
专家可以是任何模型,例如:支持向量机器 (Support Vector Machines, SVM)、高斯过程 (Gaussian processes, GP) 、隐藏马尔可夫模型(hidden Markov models, HMM)、卷积神经网络(Convolutional Neural Networks, CNN)、Transformer、ViT(Vision Transformer)。
门控模型用于解释每个专家所做的预测,并帮助决定对给定输入信任哪个专家。这被称为门控模型或门控网络,因为它传统上是一个神经网络模型。
门控网络将提供给专家模型的输入模式作为输入,并输出每个专家在对输入进行预测时应该做出的贡献。由门控网络确定的权重是根据给定的输入动态分配的,因为 MoE 系统有效地学习了每个集成成员学习了特征空间的哪一部分。
门控网络是 MoE 的关键,并且门控模型有效地学习为给定输入选择类型子任务,反过来,专家可以信任以做出强有力的预测。
MoE 也可以看作是一种分类器选择算法,其中单个分类器被训练成为特征空间某些部分的专家。
当使用神经网络模型时,门控网络和专家一起训练,以便门控网络学习何时信任每个专家进行预测。这种训练过程传统上是使用期望最大化 (Expectation Maximization, EM) 来实现的。门控网络可能有一个 softmax 输出,它为每个专家提供类似概率的置信度分数。
一般来说,训练过程试图实现两个目标:
对于给定的专家,找到最优的门控函数;对于给定的门控函数,针对门控函数指定的分布训练专家。
MoE 必须做出预测,这是通过池化或聚合机制实现的。这可能就像选择门控网络提供的具有最大输出或置信度的专家一样简单。或者,可以进行加权和预测,明确地结合每个专家的预测和门控网络估计的置信度。也可能存在其他有效利用预测和门控网络输出的方法。然后,池化/组合系统可以选择具有最高权重的单个分类器,或者计算每个类的分类器输出的加权和,并选择接收最高加权和的类。
混合专家系统 (Mixture of Experts),简称 MoE 或 ME,是一种集成学习技术,它实现了在预测建模问题的子任务上培训专家的想法。在神经网络社区中,研究人员研究了分解输入空间的 MoE 方法,以便每个专家检查空间的不同部分,门控网络负责组合各种专家。
在 MoE 架构中,一组专家和一个门控相互合作,通过将输入空间划分为一组嵌套的区域来解决非线性监督学习问题。门控对整体输入空间进行软分割,专家模型在这些区域的分区中学习特定的参数。
可以使用期望最大化 (Expectation Maximization, EM) 算法来学习专家模型和门控模型中的这些参数。
混合专家的想法可以追溯到Jacobs et al. (1991)。
首先MoE是一个层,而不是一整个模型。其次,正如我们刚才所说,这个模型结构包含一个门网络来决定激活哪个expert,同时包含n个expert网络,这n个expert网络一般是同结构的。
为了介绍基本思想,假设我们正在解决一个预测问题:
x ∈ R d ⇒ y ∈ R d . x \in \mathbb{R}^d \Rightarrow y \in \mathbb{R}^d. x∈Rd⇒y∈Rd.
让我们从学习前馈(ReLU)神经网络开始:
h θ ( x ) = W 2 max ( W 1 x , 0 ) , h_\theta(x) = W_2 \max(W_1 x, 0), hθ(x)=W2max(W1x,0),
其中参数为 θ = ( W 1 , W 2 ) \theta = (W_1, W_2) θ=(W1,W2)。
但专家的混合方法是:
定义 E E E个专家。
每个专家 e = 1 , … , E e = 1, \dots, E e=1,…,E都具有自己的嵌入 w e ∈ R d w_e \in \mathbb{R}^d we∈Rd。
将门控函数定义为 E E E个专家上的概率分布:
g e ( x ) = exp ( w e ⋅ x ) ∑ e ′ = 1 E exp ( w e ′ ⋅ x ) . g_e(x) = \frac{\exp(w_e \cdot x)}{\sum_{e' = 1}^E \exp(w_{e'} \cdot x)}. ge(x)=∑e′=1Eexp(we′⋅x)exp(we⋅x).
每个专家 e = 1 , … , E e = 1, \dots, E e=1,…,E都具有自己的参数 θ ( e ) = ( W 1 ( e ) , W 2 ( e ) ) \theta^{(e)} = (W_1^{(e)}, W_2^{(e)}) θ(e)=(W1(e),W2(e))。
根据专家特定参数定义每个专家函数:
h θ e ( x ) = W 2 ( e ) max ( W 1 ( e ) x , 0 ) . h_{\theta_e}(x) = W_2^{(e)} \max(W_1^{(e)} x, 0). hθe(x)=W2(e)max(W1(e)x,0).
f ( x ) = ∑ e = 1 E g e ( x ) ⏟ gating h θ e ( x ) ⏟ expert . f(x) = \sum_{e=1}^E \underbrace{g_e(x)}_\text{gating} \underbrace{h_{\theta_e}(x)}_\text{expert}. f(x)=e=1∑Egating ge(x)expert hθe(x).
当 G ( x ) i = 0 G(x)_i=0 G(x)i=0的时候,对应的expert就不会激活。
考虑d=2,并且每个专家都是一个线性分类器(来源):
我们可以通过反向传播来学习混合专家模型。根据链式法则,可以得到:
∇ f ( x ) = ∑ e = 1 E g e ( x ) ( ∇ ( log g e ( x ) ) h θ e ( x ) + ∇ h θ e ( x ) ) . \nabla f(x) = \sum_{e=1}^E g_e(x) (\nabla (\log g_e(x)) h_{\theta_e}(x) + \nabla h_{\theta_e}(x)). ∇f(x)=e=1∑Ege(x)(∇(logge(x))hθe(x)+∇hθe(x)).
注意到,梯度与 g e ( x ) g_e(x) ge(x)成比例,并且同时更新门控函数和专家。
g ( x ) = [ 0.04 , 0.8 , 0.01 , 0.15 ] . g(x) = [0.04, 0.8, 0.01, 0.15]. g(x)=[0.04,0.8,0.01,0.15].
正如公式所言,专家的混合不会节省任何计算,因为前向传播仍然需要评估每个专家,而反向传播也必须接触每个专家。
然而,如果我们将门控函数 g ( x ) = [ g 1 ( x ) , … , g E ( x ) ] g(x) = [g_1(x), \dots, g_E(x)] g(x)=[g1(x),…,gE(x)]近似为 g ~ ( x ) = [ g ~ 1 ( x ) , … , g ~ E ( x ) ] \tilde g(x) = [\tilde g_1(x), \dots, \tilde g_E(x)] g~(x)=[g~1(x),…,g~E(x)],其中大多数专家都是零。因此,在前向和反向传播时,我们只需要使用非零 g ~ e ( x ) \tilde g_e(x) g~e(x)的专家 e e e。
例如,我们可以选取值排名前两位(top 2)的专家,并重新规范化:
g ~ ( x ) = [ 0 , 0.84 , 0 , 0.16 ] . \tilde g(x) = [0, 0.84, 0, 0.16]. g~(x)=[0,0.84,0,0.16].
这种gate的方式会导致每个expert分到的样本太少。假设有n个experts,batch_size=b,每次会有k个expert被选择,每个expert会接收到平均kb/n << b个样本。
这里提出了一些解决方法:
数据并行和模型并行: 相当于变相的扩大b,假设有d个device,每个device上一次处理b个样本,那么在这次训练中,batch=bd,从而每个expert会接收kbd/n个样本。
单步拆分: 在我们的实验中,MoE中每个expert都是一个单层全连接,而这个层次是在LSTM层之间,因而,可以把训练LSTM时的多步给拆分开,从而相当于增大MoE训练的batch_size。
采用一些方法优化模型训练时的内存,从而进一步增大batch size。
MoETransformerBlock ( x 1 : L ) = AddNorm ( MoEFeedForward , AddNorm ( SelfAttention , x 1 : L ) ) . \text{MoETransformerBlock}(x_{1:L}) = \text{AddNorm}(\text{MoEFeedForward}, \text{AddNorm}(\text{SelfAttention}, x_{1:L})). MoETransformerBlock(x1:L)=AddNorm(MoEFeedForward,AddNorm(SelfAttention,x1:L)).
我们将top-2专家的近似门控函数定义如下:
计算第一个专家: e 1 = arg max e g e ( x ) e_1 = \arg\max_e g_e(x) e1=argmaxege(x)。
计算第二个专家: e 2 = arg max e ≠ e 1 g e ( x ) e_2 = \arg\max_{e \neq e_1} g_e(x) e2=argmaxe=e1ge(x)。
始终保留第一个专家,并随机保留第二个专家:
如果不做改进,那么这么多的expert,只有几个expert会被集中使用。为了改进这一问题,给每个expert定义了一个importance的概念。importance就是指这个expert所处理的样本数,简而言之就是G(x)对应位置的和。importance的损失函数则是importance的平方乘以一个系数。
loss = negative-log-likelihood + λ load-balancing-loss . \text{loss} = \text{negative-log-likelihood} + \lambda \text{load-balancing-loss}. loss=negative-log-likelihood+λload-balancing-loss.
例如,我们可以取 λ = 0.01 B \lambda = \frac{0.01}{B} λ=B0.01。
下面是一个 B = 2 B=2 B=2个token, E = 4 E=4 E=4个专家的例子:
g ( x 1 ) = [ 0.2 , 0.6 , 0.1 , 0.1 ] ⇒ g ~ ( x 1 ) = [ 0.25 , 0.75 , 0 , 0 ] g ( x 2 ) = [ 0.1 , 0.6 , 0.2 , 0.1 ] ⇒ g ~ ( x 2 ) = [ 0 , 0.75 , 0.25 , 0 ] g(x_1) = [0.2, 0.6, 0.1, 0.1] \Rightarrow \tilde g(x_1) = [0.25, 0.75, 0, 0] \\ g(x_2) = [0.1, 0.6, 0.2, 0.1] \Rightarrow \tilde g(x_2) = [0, 0.75, 0.25, 0] g(x1)=[0.2,0.6,0.1,0.1]⇒g~(x1)=[0.25,0.75,0,0]g(x2)=[0.1,0.6,0.2,0.1]⇒g~(x2)=[0,0.75,0.25,0]
统计为
c = [ 1 , 2 , 1 , 0 ] m = [ 0.3 , 1.2 , 0.3 , 0.2 ] c = [1, 2, 1, 0] \quad\quad\quad\quad m = [0.3, 1.2, 0.3, 0.2] c=[1,2,1,0]m=[0.3,1.2,0.3,0.2]
也就是说,我们会尝试降低专家2的权重,避免其被过度使用。
BASE需要更多的计算来优化 a a a,但更稳定。
示例:The nurse notified the patient that {her/his,their} shift would be ending in an hour.
GLaM的性别偏见少于GPT-3。
StereoSet上的结果:
The assistant went to work. {She brought her boss coffee., She was valued for her input.}
刻板印象随着模型大小的增加而变得更糟(与GLaM结果相反)。
现在,我们转向另一类语言模型,基于检索的(或检索增强的、记忆增强的模型),它可以帮助我们突破稠密Transformer的缩放上限。
Retrieval-augmented generation (RAG) is an AI framework for improving the quality of LLM-generated responses by grounding the model on external sources of knowledge to supplement the LLM’s internal representation of information. Implementing RAG in an LLM-based question answering system has two main benefits: It ensures that the model has access to the most current, reliable facts, and that users have access to the model’s sources, ensuring that its claims can be checked for accuracy and ultimately trusted.
In a 2020 paper, Meta (then known as Facebook) came up with a framework called retrieval-augmented generation to give LLMs access to information beyond their training data. RAG allows LLMs to build on a specialized body of knowledge to answer questions in more accurate way.
As the name suggests, RAG has two phases: retrieval and content generation.
In the retrieval phase, algorithms search for and retrieve snippets of information relevant to the user’s prompt or question. In an open-domain, consumer setting, those facts can come from indexed documents on the internet; in a closed-domain, enterprise setting, a narrower set of sources are typically used for added security and reliability.
This assortment of external knowledge is appended to the user’s prompt and passed to the language model. In the generative phase, the LLM draws from the augmented prompt and its internal representation of its training data to synthesize an engaging answer tailored to the user in that instant. The answer can then be passed to a chatbot with links to its sources.
Experience teaches us to stop and say when we don’t know something. But LLMs need to be explicitly trained to recognize questions they can’t answer.”
With enough fine-tuning, an LLM can be trained to pause and say when it’s stuck. But it may need to see thousands of examples of questions that can and can’t be answered. Only then can the model learn to identify an unanswerable question, and probe for more detail until it hits on a question that it has the information to answer.
RAG is currently the best-known tool for grounding LLMs on the latest, verifiable information, and lowering the costs of having to constantly retrain and update them. But RAG is imperfect, and many interesting challenges remain in getting RAG done right.
不严谨的就直接从2016年的工作起头了,其实那时候就有一个共识是,往往一些关键词出现过之后有非常大的概率重新出现,但是使用RNN,LSTM这种在进行语言建模的时候多多少少还是存在“记不住“的问题,这篇cached LM就是在之前正常的LM基础上引入了cache机制,即在语言建模时候也要从之前预测过的词语存的cache里面索引然后作为一项考虑因素。但是紧接着再2017年我们都知道这个长距离依赖问题部分被transformer的attention机制解决了,所以说这篇文章作为LM的热度并没有那么高。
当我们使用LM时候,我们到底允不允许它“拖家带口“,即训练完我们必须要把所有的训练数据扔掉,即“扔掉书本,闭卷考试“,还是说允许在预测时候允许带着所有训练数据(毕竟见一面见多面训练数据并不是犯错误),即“允许带着书进入考场,开卷考试”。这两种方式不仅仅是持续学习角度里面的不同的setting,从我们做NLP的角度来讲,实际的工业部署需求是一定要考虑在其中的,到底是把知识存到模型里面效率高一点还是把知识放在内存里面检索出来效率高,随着软硬件结构的改变随时洗牌完全有可能。
19年的时候Luke和Mike团队提出了一种knn-LM,看图就大概能明白思路,简而言之就是语言模型负责一部分下一个token(在他这里面叫target)的概率分布预测,即LM的target概率预测,然后**从语料库里面通过计算各个表示的相似度找到k个最相似的然后加权求得一个分布概率,即knn LM的target概率预测。**然后把他们加权求和得到概率预测。
把retrieve引入到语言模型的话其实每一个token都有一个key和value的(看图来理解),所以检索的数据是一个相当大的规模的数据。他这里面使用了FAISS(一个用来快速寻找k个最大相似向量的系统)来加速。
另一条线上,检索方法的重度使用区其实是开放域问答和机器阅读理解(著名的SQuAD benchmark等等都是开放域问答),也就是陈丹琦所在的领域。这几年检索增强的NLP论文里面大家经常提的一篇论文也是她在ACl 2017年的论文[9],当然我们这里面不拘泥于QA,只是去了解一下他当时的操作和思路。
系统分为两大块,首先有一个 Document Retriever,对给定的问题 question,从所有维基百科文章中检索(这里面注意SQuAD 实际上是有具体的位置的,但是WebQuestion等等数据集就是没有的,这里面她是通过远程监督的方式做的处理);检索到文章后切分成段落,然后用一个称之为 Document Reader 的模块在段落中预测答案位置并给出分数。后者其实就是标准的阅读理解模型了,完全可以替换成其他的机器阅读理解模型。
现在看来,至今也都是Retriever+Reader的这个路子,只不过,1)检索和理解的对象;2)检索器和理解器所使用的工具;3)检索端和理解端端侧重都沧海桑田地发生了变化。我们后面慢慢来看。
让我们首先关注使用编码器-解码器框架的序列到序列任务:
input x ⇒ output y \text{input } x \quad\Rightarrow\quad \text{output } y input x⇒output y
示例(开放问答):
回想一下,BART和T5是编码器-解码器模型的代表:
p ( y ∣ x ) p(y \mid x) p(y∣x)
其使用去噪目标函数进行训练。
例如:
输入 x x x:Thank you < X >
输出 y y y: < X >
假设我们有一个存储库 S S S,它是一组序列(通常是文档或段落)的集合。
S = { Why is the... , Thanks for , . . . , The quick... , Stanford... } . S = \{ \text{Why is the...}, \text{Thanks for}, ..., \text{The quick...}, \text{Stanford...} \}. S={Why is the...,Thanks for,...,The quick...,Stanford...}.
基于检索的模型直观的生成过程:
示例(开放问答):
最近邻是最常用的一种检索方法:
Generalization through memorization: Nearest neighbor language models [ICLR 2020]
这篇论文提出了一种将 kNN 最近邻模型与语言模型预测分布相结合的方法,
NN-LM。该方法可以有效提高预训练语言模型对例如事实知识等稀有文本形式的建模能力。
这篇论文在语言模型预训练阶段中引入了一个知识检索模型。相比于不能在下游任务中微调的
K N N − L M KNN-LM KNN−LM,REALM 能够在预训练、微调和推理阶段显式利用大型语料库中的知识。
形式上,RAG模型定义如下:
( y ∣ x ) = ∑ z ∈ S p ( z ∣ x ) ⏟ retriever p ( y ∣ z , x ) ⏟ generator . (y \mid x) = \sum_{z \in S} \underbrace{p(z \mid x)}_\text{retriever} \underbrace{p(y \mid z, x)}_\text{generator}. (y∣x)=z∈S∑retriever p(z∣x)generator p(y∣z,x).
在实践中, ∑ z ∈ S \sum_{z \in S} ∑z∈S由前k个代替(类似于为混合专家选择前1个或2个专家)。
Dense Passage Retrieval (DPR)** (Karpukhin et al., 2020)
p ( z ∣ x ) = exp ( BERT d ( z ) ⋅ BERT q ( x ) ) ∑ z ′ ∈ S exp ( BERT d ( z ′ ) ⋅ BERT q ( x ) ) . p(z \mid x) = \frac{\exp(\text{BERT}_\text{d}(z) \cdot \text{BERT}_\text{q}(x))}{\sum_{z' \in S} \exp(\text{BERT}_\text{d}(z') \cdot \text{BERT}_\text{q}(x))}. p(z∣x)=∑z′∈Sexp(BERTd(z′)⋅BERTq(x))exp(BERTd(z)⋅BERTq(x)).
p ( y ∣ z , x ) = p ( y ∣ concat ( z , x ) ) . p(y \mid z, x) = p(y \mid \text{concat}(z, x)). p(y∣z,x)=p(y∣concat(z,x)).
在Jeopardy问题生成任务上,输入Hemingway的检索结果:
实验结果表明,优于非检索方法:
这里引用GPT-3 few-shot的结果进行比较:NaturalQuestions (29.9%), WebQuestions (41.5%), TriviaQA (71.2%)
Improving Language Models by Retrieving from Trillions of Tokens
这项工作提出了一种从大规模语料库中检索文档块用于增强自回归模型的方法。相较于之前检索增强生成的方法,该论文使用了超大规模的检索语料库,表明了检索语料库规模的可扩展性。
首先,RETRO 建立了用于检索的数据库,其键为 BERT 编码表示,值为对应的文档块以及该文档块的下一个文档块。RETRO 的模型架构包含检索器,编码器以及带有块注意力(Chunked cross attention, CCA)模块的解码器。模型将输入也分为若干固定大小的块,基于已经生成的上下文及上下文分块检索到的文档块生成当前目标词。