大语言模型-RLHF(四)-PPO(Proximal Policy Optimization)原理&实现&代码逐行注释

 前言

从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO原理

大语言模型-RLHF(四)-PPO(Proximal Policy Optimization)原理&实现&代码逐行注释_第1张图片


要搞明白PPO首先需要搞明白下面几个概念

一,策略梯度(Policy Gradient)

策略梯度(Policy Gradient)是一种用于强化学习中的策略优化方法,其核心思想是直接优化策略函数。策略函数可以理解为一个神经网络π(a∣s),描述的是在给定状态s下,采取不同动作a的概率分布。θ可以理解为策略神经网络π(a∣s)的参数,我们需要优化的就是这个θ,策略梯度的公式如下:

\nabla_{\theta} J(\theta) = E_{\tau \sim p_{\theta}(\tau)}[\sum_{t=0}^T \nabla_{\theta} \log \pi_{\theta}(a_t|s_t) A_t]

通常使用梯度上升法来更新策略函数θ,使其能够最大化期望回报。

我们改写一下θ的更新公式如: \theta_{t+1} = \theta_t + \alpha \nabla_{\theta} \log \pi_{\theta}(a_t|s_t) G_t

其中θ表示策略函数的参数是我们优化的目标,st​表示状态,at​表示动作,Gt​表示从时刻t开始的回报总和,α表示学习率,控制每次更新的步长大小。通过梯度上升法我们不断更新θ

可以看到上面θ是串行更新的,所以耗时比较久。因此我们引入 Off Policy的概念

二,On Policy Off Policy

On-policy和Off-policy是强化学习中两种不同的sample data的学习方式。

On-policy学习是指使用同一个策略来采集经验数据和更新价值函数或者优化策略。在On-policy学习中,我们使用当前策略来采集经验数据,然后使用这些数据来更新价值函数或者优化策略。由于采集数据和更新策略使用的是同一个策略,因此On-policy学习通常比较稳定,问题是收敛速度较慢。刚刚(Policy Gradient)很显然是个On-policy的方法

Off-policy学习是指使用不同的策略来采集经验数据和更新价值函数或者优化策略。在Off-policy学习中,我们使用一个策略(称为行为策略)来采集经验数据,然后使用另一个策略(称为目标策略)来更新价值函数或者优化策略。由于采集数据和更新策略使用的是不同的策略,因此Off-policy学习通常比较灵活,但是可能会导致采样偏差的问题。

大语言模型-RLHF(四)-PPO(Proximal Policy Optimization)原理&实现&代码逐行注释_第2张图片

画了个简单的示意图如上,同样经过三轮迭代。On-policy需要等待三次,而Off-policy可以并行开始,只需要一次就可以达到目的。大大提升了效率。

off policy公式如下,对比上一节on policy可以看到J(θ是由另外u分布采样得到的,而不是θ。

J(\theta) = E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)]

其中,τ表示一个轨迹,pμ​(τ)表示根据策略函数πθ​(at​∣st​)生成的轨迹的概率分布,at​表示在状态st​下选择的动作,At​表示在状态st​下选择动作at​后获得的优势值,

另外对比上一节公式可以看到,除了分布采样变化还多了个ρt​=πθ​(at​∣st​)/πμ​​(at​∣st​),这是因为off policy会有采样偏差的问题,所以需要引入一个概念重要性采样,来解决这个问题

三,重要性采样(Importance Sampling):

刚刚提到,off policy会有采样偏差的问题,所以引入一个概念重要性采样,那么什么是重要性采样呢?

重要性采样(Importance Sampling)是一种通用的用于估计期望值的技术,常用于蒙特卡罗积分和概率分布采样等问题。重要性采样的基本思想是,通过从一个分布中采样,来估计另一个分布中的期望值。其公式如下

$E_{p(x)}[f(x)] \approx \frac{1}{N} \sum_{i=1}^{N} f(x_i) \frac{p(x_i)}{q(x_i)}$

其中,xi表示从简单分布 q(x)中采样得到的样本, N表示采样的样本数。

需要注意的是,为了保证估计的准确性,两个分布我们应该让他们尽量接近。否则,如果分布差异较大,即使期望接近,方差还是会很大,导致估计的不准确。

四,自适应的KL散度惩罚(Adaptive KL Penalty Coefficient)

刚刚说了采样分布差异不宜过大,为了保障分布一致,所以我们又引入了一个KL散度惩罚

$J(\theta) = E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)] - \beta D_{KL}(\pi_{\theta}(a_t|s_t)||\pi_{\mu}(a_t|s_t))] $

公式如上,−βDKL​(πθ​(at​∣st​)∣∣πμ​(at​∣st​)):这部分是KL散度的惩罚项,用于控制策略更新的幅度。它表示当前策略函数θ与策略函数μ之间的KL散度。通过最小化这一项,可以限制策略更新的幅度,以保持策略的连续性。

这里Adaptive指的是β的学习,但是控制分布差异是个困难的活儿,β权重太大,会导致两个分布的差异过小,没啥变化从而学习不到内容,β权重太小又会导致分布差异控制失效,分布差异过大导致模型学习不准。

五,约束问题(Clipped Surrogate Objective)

为了解决分布控制,PPO2放弃了KL散度,引入了(Clipped Surrogate Objective),这也是"Proximal Policy Optimization Algorithms"中,"Proximal"一词的由来。近端(Proximal)优化方法是一种通过在目标函数中引入正则化项或约束来限制参数更新的方法。

具体公式如下

J^{CLIP}(\theta) = E_{\tau \sim p_{\theta_{old}}(\tau)}[\min(r_t(\theta)A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]

其中,rt​(θ) 是前面提到的重要性采样比率,用于校正在旧策略下采样得到的数据。At​表示优势函数,用于衡量在状态st​下采取动作at​相对于平均水平的优劣程度。clip(x,a,b)表示将x限制在区间[a,b]内。直观上也比较好理解,如果分布符合我们希望,要多采样一点,即rt​(θ) 大一些但是不能超过1+ϵ。如果分布不符合我们希望,要少采样一点,即rt​(θ) 小一些但是不能小于1-ϵ。从而把r(θ)限定到了一个范围,避免了上一段β调节过度的问题。

最后把约束和上面policy优化结合,得到公式

$J(\theta) = E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)] + J^{CLIP}(\theta) $

其中,J(θ)表示目标函数,τ表示一个轨迹,pμ​(τ)表示根据策略函数μ生成的轨迹的概率分布,At​表示在状态st​下选择动作at​后获得的优势值,ρt​是重要性采样比率,用于校正在其他策略下采样得到的数据,πθ​(at​∣st​)表示在状态st​下选择动作at​的概率。

六, 完整代码可以参考:

GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

你可能感兴趣的:(算法,机器学习,人工智能,AIGC)