文章给出了一种基于人类指令对大语言模型进行调整的方法。在人类标注的数据上对GPT-3进行微调,并通过人类打分的数据对上述模型进行强化学习,从而得到一个满足人类偏好的语言模型。文章的整体架构是ChatGPT的前身,相较于GPT-3,回答真实性更高,且危害信息更少。
文章的整体架构如下
文章架构中的第一部分Step1采用了SFT方法。为此,文章首先雇用了一些标记员对prompt数据进行回答,称为labeler demonstrations,然后在demonstrations上面对GPT-3进行有监督的微调。微调后的模型我们记作SFT。
Prompt数据集通过如下方式构建:从OpenAI的Playground API(非生产环境)上面获取用户提交的prompt,其中保证使用每个用户的prompt数量不超过200条,并过滤敏感信息和重复信息(通过长的公共前缀过滤),然后按照user ID划分train, test, val。另外由于API的prompts的多样性较为单一,我们让我们的标记员自己写一些如下类别的prompts以初始化模型:
接下来我们训练一个RM打分模型,为此首先需要一个评分的数据集。OpenAI 构建了一个用户打分的UI,如下图所示。首先上述的SFT模型会对每个prompt生成 K K K个候选答案( 4 ≤ K ≤ 9 4\le K \le 9 4≤K≤9),用户会对每个答案打分,并对所有答案进行排序。页面的示例如下图所示。为了避免模型overfit且保证训练效率,模型训练时按照prompt对数据集划分,即保证每个prompt的所有 ( K 2 ) \tbinom K2 (2K)答案对在同一个batch。
RM模型根据上述打分数据进行训练,具体训练方案如下图。损失函数为 l o s s ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) ∼ D [ log ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ] loss(\theta) = -\frac 1{\tbinom K2} E_{(x, y_w, y_l)\sim D} [\log (\sigma (r_{\theta} (x, y_w) - r_{\theta} (x, y_l) )] loss(θ)=−(2K)1E(x,yw,yl)∼D[log(σ(rθ(x,yw)−rθ(x,yl))],其中 r θ ( x , y ) r_{\theta} (x, y) rθ(x,y)为当前训练的RM模型的输出, x , y x, y x,y分别为prompt及其对应的回答, y w y_w yw为 y w , y l y_w, y_l yw,yl中更受用户喜爱的回答(ranking更高的)。训练的目的为使得loss尽可能小,即另每个 ( r θ ( x , y w ) − r θ ( x , y l ) ) (r_{\theta} (x, y_w) - r_{\theta} (x, y_l) ) (rθ(x,yw)−rθ(x,yl))尽可能大,即保证模型对 y w y_w yw打分高于 y l y_l yl越多越好。
接下来,我们用SFT模型预测每个prompt,并用RW预测人类偏好(rankings),再通过RL方法更新SFT的参数。文章在Stiennon[1]的工作基础上增加了PPO梯度,从而保证模型不损失原始GPT-3的回归性能。具体的目标函数为 o b j e c t i v e ( ϕ ) = E ( x , y ) ∈ D π ϕ [ r θ ( x , y ) − β log ( π ϕ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) ] + γ E x ∼ D p r e t r a i n [ log ( π ϕ R L ( x ) ) ] objective(\phi) = E_{(x, y) \in D_{\pi_\phi}}[r_{\theta}(x, y) - \beta\log(\pi_\phi^{RL}(y|x)/\pi^{SFT}(y|x))] + \gamma E_{x\sim D_{pretrain}} [\log (\pi_\phi^{RL}(x))] objective(ϕ)=E(x,y)∈Dπϕ[rθ(x,y)−βlog(πϕRL(y∣x)/πSFT(y∣x))]+γEx∼Dpretrain[log(πϕRL(x))],其中 π ϕ R L \pi_\phi^{RL} πϕRL为当前学习的RL策略的输出, π S F T \pi^{SFT} πSFT为SFT模型的输出, D p r e t r a i n D_{pretrain} Dpretrain为预训练模型GPT-3的分布, β \beta β表示KL奖励系数用于控制KL惩罚项, γ \gamma γ表示预训练模型损失系数用于控制预训练模型的梯度更新。RL训练的具体方案如下图
文章在GPT-3的基础上,通过有监督的微调SFT和基于人类反馈的模型训练RM,得到了一个更符合人类偏好的大语言模型。文章的重点突破为采用RL方法增强了模型的可信度,降低模型输出危害回答的概率。InstructGPT是ChatGPT的前身,是Chat GPT面世必不可少的一步。
Training language models to follow instructions with human feedback
[1] Learning to summarize from human feedback
[2] GPT-3原文:Language Models are Few-Shot Learners
[3] GPT-3论文笔记