ChatGPT 中的人类反馈强化学习 (RLHF) 实战

目录

  • 1 前言
  • 2 人类反馈强化学习 (RLHF)
    • 2.1 奖励模型 (RM)
    • 2.2 近端策略优化算法 (PPO)
  • 3 总结
  • 4 参考


团队博客: CSDN AI小组


相关阅读

  • ChatGPT 简介
  • 大语言模型浅探一
  • 关于 ChatGPT 必看的 10 篇论文
  • 从 ELMo 到 ChatGPT:历数 NLP 近 5 年必看大模型

1 前言

在当今数字化的时代,ChatGPT 的火热程度不断升级。ChatGPT 可以处理复杂的语言任务,从而解放人力资源,提高工作效率,减少成本。ChatGPT 的先进技术和广泛应用,使得它成为了当今最炙手可热的人工智能技术之一。无论是企业、学术机构,还是科技爱好者,都对 ChatGPT 的应用前景充满期待。

在这样的背景之下,CSDN AI 团队也想对 ChatGPT 进行简单的复现。根据ChatGPT官方博客可知,ChatGPT的训练方法与 InstructGPT 的训练方法基本一致 (如图1所示),只是使用的数据集不一样。故在训练方法上,我们主要参考 InstructGPT 进行复现,基础模型使用的是 RWKV,拆分后共包含以下四个阶段:

  • (1) 语言模型预训练 (Language Model Pre-training);
  • (2) 有监督指令微调 (Supervised Fine-Tuning, SFT);
  • (3) 奖励模型的训练 (Reward Modeling, RM);
  • (4) 使用近端策略优化算法进行强化学习 (Proximal Policy Optimization, PPO).

第 (1)、(2) 阶段的 Pre-training 和 SFT 由 @zxm2015 完成,可参考文章大语言模型浅探一。本文主要介绍第 (3)、(4) 阶段的内容,即人类反馈强化学习 (Reinforcement Learning from Human Feedback, RLHF)。

图1 InstructGPT 模型的训练过程

2 人类反馈强化学习 (RLHF)

人类反馈强化学习 (RLHF) 是 ChatGPT 中一种用于改善其回答效果的算法。它是一种基于强化学习的方法,通过结合人类反馈来优化 ChatGPT 的回答。

在 RLHF 中,ChatGPT 学习通过和人类用户的交互来提高其回答的质量。当 ChatGPT 生成一个回答时,它会将回答展示给用户并请求用户的反馈。用户可以对回答进行评分,比如“好”、“不错”、“一般”、“差”等。ChatGPT 会将用户的反馈作为奖励或惩罚信号,以此来更新自己的模型,以更好地满足用户的需求。

RLHF 可分为两个部分。第一部分是奖励模型,人类反馈主要就体现在这个地方;第二部分采用近端策略优化算法的强化学习阶段,基于奖励模型的反馈来优化模型,最终得到满足人类偏好的语言模型。下面将对这两个部分进行详细的说明。

2.1 奖励模型 (RM)

在 RLHF 之前,语言模型已经进行了 SFT (后续称该模型为 SFT Model),而奖励模型的任务主要是对 SFT Model 的回复进行打分,打分越高表示回答效果越好。训练好奖励模型之后,就可以用于下一阶段的 PPO 进行强化学习的调优,奖励模型是 PPO 中的一个子部分,用于 PPO 训练时提供奖励信号。

(1) 模型的输入输出
模型的输入是用户提问 (Prompt) 和 SFT Model 回复 (Response) 的 pair 对 ,输出是一个奖励得分,如下图所示:

图2 RM 的输入和输出

(2) 数据集的构建
这个阶段主要是通过人工标注训练数据,来训练 RM,人类反馈就体现在这个地方。在 Prompts 数据集中随机抽取问题,对于每个问题,生成 K 个不同的回答。人类标注者对这些结果综合考虑(例如:相关性、富含信息性、有害信息等诸多标准)给出排名顺序。

按照上述奖励模型的输入输出描述,构建数据集时应该是人工对 进行打分,但实际上对多个回答进行打分比较困难,得分是连续的,这会降低标注的速度。此外,我们其实关注的是多个选项之间哪个更好,哪个更差。所以标注的时候对多个选项进行排序就可以了,最后基于排序后的回答,构建数据集,选用合适的损失函数即可。

通常情况下,人类进行排序任务,当选项为 4-9 个 (即 K∈{4, 5, 6, 7, 8, 9}) 时速度最快且效果最准确,此处我们设定 K=4。最终一个 Prompt 我们就可以得到 C(4, 2)=6 条训练样本。

具体而言,假设我们选定了一个问题 x,接着使用 SFT Model 生成了 4 个回答 {y1, y2, y3, y4},人类标注者进行排序后为 y4 > y3 > y1 > y2},则得到的训练样本如下所示,左边的得分要高于右边:

(, )
(, )
(, )
(, )
(, )
(, )

(3) 损失函数
根据上面构建的数据集可知,我们没有连续的得分目标去训练奖励模型,但是有正负例样本对,所以损失函数如下所示,该损失函数需要最小化:
其中,r(x,y) 为 输入到 RM 模型的得分,θ 是 RM 的参数,yw 和 yl 是输入为 x 时,SFT Model 生成的不同回答,其中人工标注时 yw > yl.

# loss function
def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

(4) 核心代码
RM 的网络结构相比于 SFT Model,并不需要做太大的改动,输入 后,直接取最后一个 token 的 embedding,在其后面接一个线性层计算奖励得分即可

a) 线性层:

# reward 得分计算
self.pred_reward = nn.Linear(dim, 1, bias=False)

b) forword 函数

    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None
    ):

        # prompt_mask 和 prompt_lengths 只能二选一
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device=x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # reward model should have an understanding of which section is prompt, and which section is response
        # 根据 prompt_mask 中 token 的 True 和 False,从 prompt_embed 或 response_embed 中取值
        # 如果为 True,则从 prompt_embed 中选,否则从 response_embed 中选
        prompt_response_mask_embed = torch.stack([
            self.prompt_embed,
            self.response_embed,
            self.padding_embed
        ]).to(prompt_mask.device)
        extra_embed = None
        if exists(prompt_mask):
            extra_embed = prompt_response_mask_embed[prompt_mask]            

        # 获得最后一个 token 的 embedding
        last_token_embeds = self.rwkv(
            x,
            extra_embed=extra_embed,
            rm_train=True
        )[:, -1, :]

        # 计算奖励
        reward = self.pred_reward(last_token_embeds)
        reward = reward.squeeze(-1)

        return reward

c) train_forward 函数

    def train_forward(self, x_p, x_a, m_p, m_a):
        # 因为前向传播的时候,需要过两次模型。所以反馈的时候需要冻结其中一次的参数
        # 不然梯度会被计算两次,在包含 deepspeed 框架下会报错
        # 报错信息:Gradient computed twice for this partition.

        with torch.enable_grad():
            prefer_reward = self.forward(x_p, prompt_mask=m_p)
        with torch.no_grad():
            alter_reward = self.forward(x_a, prompt_mask=m_a)

        return prefer_reward, alter_reward

2.2 近端策略优化算法 (PPO)

近端策略优化算法(Proximal Policy Optimization, PPO)是一种深度强化学习算法,其目标是学习一个能够最大化长期累积回报的策略。

图3 PPO 训练架构详细版本

(1) PPO算法包含以下几个主要部分:

  • a) 策略网络 (Policy Network)
    用于学习并输出给定状态下不同行动的概率分布。它通常是一个神经网络,可以根据环境的反馈进行更新。对应图3中的 Actor,使用 SFT Model 进行初始化,在 PPO 中需要参与训练。

  • b) 价值网络 (Value Network)
    用于预测给定状态的预期回报值。它通常也是一个神经网络,它的输出可以用来计算优势函数,从而帮助更新策略网络。对应图3中的 Critic,使用 RM 进行初始化,在 PPO 中需要参与训练。

  • c) 奖励模型
    对应图3中的 Reward Model,是 2.1 节中训练得到的模型,在 PPO 中不参与训练,只提供奖励信号,用于 PPO 的训练。

  • d) SFT Model
    对应图3中的 Supervised Fine-Tune Model,用于更新策略网络,以使其能够产生更好的策略。通过限制每次更新的幅度,从而确保更新后的策略与原始策略之间的差异不会太大。该部分可以参与训练,也可以不参与,当参与训练时,PPO 被称为 PPO-ptx。

  • e) 经验采样
    用于收集与环境交互的经验数据,以供策略网络和价值网络的更新使用。在PPO算法中,经验采样通常采用基于行动价值估计的策略。对应图3中顶部的 Prompts -> Actor -> Response 流程。

图4 PPO 训练架构简化版本

(2)损失函数

  • a) actor loss (也称为 policy loss, 是最终要使用模型的 loss)
    其中,πRL 是 actor,πSFT 是已经训练好的 SFT Model。损失函数的第1项和第2项是核心部分,第3项是可选项。该损失函数需要最大化。具体如下:
    • 第一项:这一项是奖励模型 RM 奖励得分,奖励需要最大化;
    • 第二项:这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值;
    • 第三项:这一项是预训练梯度 (可选项),传统的 PPO 中一般不包含该项,InstructGPT 中加入这一项是为了避免 RLHF 导致大模型在公开的 NLP 评测任务上效果下降。加入该项之后被命名为 PPO-ptx。
  • b) critic loss (也称为 value loss)
    使用的是 clipped_value_loss。

(3)核心代码
a) training_step

    def training_step(self, batch, batch_idx, optimizer_idx):
        sequences, \
        prompt_masks, \
        masks, \
        old_action_probs, \
        old_log_probs, \
        rewards, \
        old_values = batch

        # PPO training
        action_masks = ~prompt_masks & masks

        action_logits, values = self.actor_critic(
            sequences,
            mask = action_masks
        )

        action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
        action_len = old_log_probs.shape[-1]

        action_probs = action_logits.softmax(dim = -1)
        action_log_probs = log_prob(action_probs, sequences)
        action_log_probs = action_log_probs[:, -action_len:]

        # calculate entropies, taking into account which part of the sequence is actually an action

        entropies = masked_entropy(action_probs, mask = action_masks)

        # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not

        kl_div_loss = 0.

        if self.args.kl_div_loss_weight > 0:
            kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

        # handle non-pooled values

        normalize_kwargs = dict()

        if old_values.ndim == 2:
            old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

            old_values = old_values[:, -action_len:]
            values = values[:, -action_len:]
            rewards = rearrange(rewards, 'b -> b 1')
            normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

        if values.ndim < rewards.ndim:
            values = rearrange(values, '... -> ... 1')

        # calculate clipped surrogate objective, classic PPO loss

        ratios = (action_log_probs - old_log_probs).exp()
        advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

        if advantages.ndim == 1:
            advantages = rearrange(advantages, 'b -> b 1')

        surr1 = ratios * advantages
        surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
        policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

        # actor loss (也称为 policy loss, 是最终要使用模型的 loss)
        if optimizer_idx == 0:
            actor_loss = policy_loss.mean() + kl_div_loss
            return actor_loss

        # critic loss (也称为 value loss)
        # update value network separate from policy network
        if optimizer_idx == 1:
            critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
            critic_loss = critic_loss.mean()
            return critic_loss

b) gen_experience_dataset

    def gen_experience_dataset(self):
        ''' 通过与 environment 交互产生训练数据
        '''
        
        device = self.device

        time_cnt = 0
        for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
            for timestep in range(self.args.max_timesteps):
                time_cnt += 1

                # select a bunch of random states (prompts)
                # and get the action (sampled sequence from rwkv as well as the action probs)
                # also calculate the reward using reward model and store
                # 随机挑选一条 prompt
                rand_prompt_index = randrange(0, len(self.prompts))
                state = self.prompts[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.args.pad_value
                state = state[state_mask]

                # get predicted sequence
                # 与 environment 进行交互,其中返回的:
                #   action 是 response,
                #   sequence 是 prompt + response, 
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = self.args.ctx_len,
                    return_values = True
                )
                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim = -1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)

                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )

                self.sequence_batch.append(sequence)
                self.prompt_mask_batch.append(prompt_mask)
                self.mask_batch.append(mask)
                self.action_prob_batch.append(action_prob)
                self.action_log_prob_batch.append(action_log_prob)
                self.reward_batch.append(reward)
                self.value_batch.append(value)

                if time_cnt % self.args.update_timesteps == 0:
                    train_data = zip(
                        self.sequence_batch, self.prompt_mask_batch, self.mask_batch, 
                        self.action_prob_batch, self.action_log_prob_batch, self.reward_batch, 
                        self.value_batch
                    )

                    for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
                        yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value

                    self.sequence_batch.clear()
                    self.prompt_mask_batch.clear()
                    self.mask_batch.clear()
                    self.action_prob_batch.clear()
                    self.action_log_prob_batch.clear()
                    self.reward_batch.clear()
                    self.value_batch.clear()

3 总结

RLHF 可以根据用户反馈不断学习和优化对话,从而提高对话的质量和效果。但是由于算力资源的限制,我们只是简单调试并拉通了 RLHF 的训练流程,暂未在实际的数据集上训练模型。如若有纰漏指出,还请指正,感谢!

4 参考

[1] InstructGPT
[2] ChatGPT 背后的“功臣”——RLHF 技术详解
[3] ColossalAI
[4] PaLM-rlhf-pytorch
[5] Promixal Policy Optimization with PyTorch
[6] How ChatGPT Works Part 2: The Reward Model

你可能感兴趣的:(博客质量分测试,chatgpt,人工智能,深度学习)