团队博客: CSDN AI小组
相关阅读
在当今数字化的时代,ChatGPT 的火热程度不断升级。ChatGPT 可以处理复杂的语言任务,从而解放人力资源,提高工作效率,减少成本。ChatGPT 的先进技术和广泛应用,使得它成为了当今最炙手可热的人工智能技术之一。无论是企业、学术机构,还是科技爱好者,都对 ChatGPT 的应用前景充满期待。
在这样的背景之下,CSDN AI 团队也想对 ChatGPT 进行简单的复现。根据ChatGPT官方博客可知,ChatGPT的训练方法与 InstructGPT 的训练方法基本一致 (如图1所示),只是使用的数据集不一样。故在训练方法上,我们主要参考 InstructGPT 进行复现,基础模型使用的是 RWKV,拆分后共包含以下四个阶段:
第 (1)、(2) 阶段的 Pre-training 和 SFT 由 @zxm2015 完成,可参考文章大语言模型浅探一。本文主要介绍第 (3)、(4) 阶段的内容,即人类反馈强化学习 (Reinforcement Learning from Human Feedback, RLHF)。
人类反馈强化学习 (RLHF) 是 ChatGPT 中一种用于改善其回答效果的算法。它是一种基于强化学习的方法,通过结合人类反馈来优化 ChatGPT 的回答。
在 RLHF 中,ChatGPT 学习通过和人类用户的交互来提高其回答的质量。当 ChatGPT 生成一个回答时,它会将回答展示给用户并请求用户的反馈。用户可以对回答进行评分,比如“好”、“不错”、“一般”、“差”等。ChatGPT 会将用户的反馈作为奖励或惩罚信号,以此来更新自己的模型,以更好地满足用户的需求。
RLHF 可分为两个部分。第一部分是奖励模型,人类反馈主要就体现在这个地方;第二部分采用近端策略优化算法的强化学习阶段,基于奖励模型的反馈来优化模型,最终得到满足人类偏好的语言模型。下面将对这两个部分进行详细的说明。
在 RLHF 之前,语言模型已经进行了 SFT (后续称该模型为 SFT Model),而奖励模型的任务主要是对 SFT Model 的回复进行打分,打分越高表示回答效果越好。训练好奖励模型之后,就可以用于下一阶段的 PPO 进行强化学习的调优,奖励模型是 PPO 中的一个子部分,用于 PPO 训练时提供奖励信号。
(1) 模型的输入输出
模型的输入是用户提问 (Prompt) 和 SFT Model 回复 (Response) 的 pair 对
(2) 数据集的构建
这个阶段主要是通过人工标注训练数据,来训练 RM,人类反馈就体现在这个地方。在 Prompts 数据集中随机抽取问题,对于每个问题,生成 K 个不同的回答。人类标注者对这些结果综合考虑(例如:相关性、富含信息性、有害信息等诸多标准)给出排名顺序。
按照上述奖励模型的输入输出描述,构建数据集时应该是人工对
通常情况下,人类进行排序任务,当选项为 4-9 个 (即 K∈{4, 5, 6, 7, 8, 9}) 时速度最快且效果最准确,此处我们设定 K=4。最终一个 Prompt 我们就可以得到 C(4, 2)=6 条训练样本。
具体而言,假设我们选定了一个问题 x,接着使用 SFT Model 生成了 4 个回答 {y1, y2, y3, y4},人类标注者进行排序后为 y4 > y3 > y1 > y2},则得到的训练样本如下所示,左边
(
, )
(, )
(, )
(, )
(, )
(, )
(3) 损失函数
根据上面构建的数据集可知,我们没有连续的得分目标去训练奖励模型,但是有正负例样本对,所以损失函数如下所示,该损失函数需要最小化:
其中,r(x,y) 为
(4) 核心代码
RM 的网络结构相比于 SFT Model,并不需要做太大的改动,输入
a) 线性层:
b) forword 函数
c) train_forward 函数
近端策略优化算法(Proximal Policy Optimization, PPO)是一种深度强化学习算法,其目标是学习一个能够最大化长期累积回报的策略。
(1) PPO算法包含以下几个主要部分:
a) 策略网络 (Policy Network)
用于学习并输出给定状态下不同行动的概率分布。它通常是一个神经网络,可以根据环境的反馈进行更新。对应图3中的 Actor,使用 SFT Model 进行初始化,在 PPO 中需要参与训练。
b) 价值网络 (Value Network)
用于预测给定状态的预期回报值。它通常也是一个神经网络,它的输出可以用来计算优势函数,从而帮助更新策略网络。对应图3中的 Critic,使用 RM 进行初始化,在 PPO 中需要参与训练。
c) 奖励模型
对应图3中的 Reward Model,是 2.1 节中训练得到的模型,在 PPO 中不参与训练,只提供奖励信号,用于 PPO 的训练。
d) SFT Model
对应图3中的 Supervised Fine-Tune Model,用于更新策略网络,以使其能够产生更好的策略。通过限制每次更新的幅度,从而确保更新后的策略与原始策略之间的差异不会太大。该部分可以参与训练,也可以不参与,当参与训练时,PPO 被称为 PPO-ptx。
e) 经验采样
用于收集与环境交互的经验数据,以供策略网络和价值网络的更新使用。在PPO算法中,经验采样通常采用基于行动价值估计的策略。对应图3中顶部的 Prompts -> Actor -> Response 流程。
(2)损失函数
(3)核心代码
a) training_step
b) gen_experience_dataset
RLHF 可以根据用户反馈不断学习和优化对话,从而提高对话的质量和效果。但是由于算力资源的限制,我们只是简单调试并拉通了 RLHF 的训练流程,暂未在实际的数据集上训练模型。如若有纰漏指出,还请指正,感谢!
[1] InstructGPT
[2] ChatGPT 背后的“功臣”——RLHF 技术详解
[3] ColossalAI
[4] PaLM-rlhf-pytorch
[5] Promixal Policy Optimization with PyTorch
[6] How ChatGPT Works Part 2: The Reward Model