【强化学习】16 ——PPO(Proximal Policy Optimization)

文章目录

  • 前言
    • TRPO的不足
    • PPO特点
  • PPO-惩罚
  • PPO-截断
  • 优势函数估计
  • 算法伪代码
  • PPO 代码实践
  • 参考

前言

TRPO 算法在很多场景上的应用都很成功,但是我们也发现它的计算过程非常复杂,每一步更新的运算量非常大。于是,TRPO 算法的改进版——PPO 算法在 2017 年被提出,PPO 基于 TRPO 的思想,但是其算法实现更加简单。并且大量的实验结果表明,与 TRPO 相比,PPO 能学习得一样好(甚至更快),这使得 PPO 成为非常流行的强化学习算法。

TRPO的不足

回顾一下TRPO算法 θ ′ ← arg ⁡ max ⁡ θ ′ ∑ t E s t ∼ p θ ( s t ) [ E a t ∼ π θ ( a t ∣ s t ) [ π θ ′ ( a t ∣ s t ) π θ ( a t ∣ s t ) γ t A π θ ( s t , a t ) ] ] s u c h   t h a t   E s t ∼ p ( s t ) [ D K L ( π θ ′ ( a t ∣ s t ) ∥ π θ ( a t ∣ s t ) ) ] ≤ ϵ \begin{aligned} &\theta'\leftarrow\arg\max_{\theta'}\sum_t\mathbb{E}_{s_t\sim p_\theta(s_t)}[\mathbb{E}_{a_t\sim\pi_\theta(a_t|s_t)}[\frac{\pi_{\theta'}(a_t|s_t)}{\pi_\theta(a_t|s_t)}\gamma^tA^{\pi_\theta}(s_t,a_t)]] \\ &\mathrm{such~that~}\mathbb{E}_{s_t\sim p(s_t)}[D_{KL}(\pi_{\theta^{\prime}}(a_t|s_t)\parallel\pi_\theta(a_t|s_t))]\leq\epsilon \end{aligned} θargθmaxtEstpθ(st)[Eatπθ(atst)[πθ(atst)πθ(atst)γtAπθ(st,at)]]such that Estp(st)[DKL(πθ(atst)πθ(atst))]ϵ
TRPO算法主要存在以下不足:

  • 第一个问题是重要性采样的通病,重要性采样中的比例 π θ ′ ( a t ∣ s t ) π θ ( a t ∣ s t ) \frac{\pi_{\theta'}(a_t|s_t)}{\pi_\theta(a_t|s_t)} πθ(atst)πθ(atst)会带来较大的方差。
  • 求解约束优化问题比较困难
  • 近似求解会带来误差

TRPO 使用泰勒展开近似、共轭梯度、线性搜索等方法直接求解。PPO 的优化目标与 TRPO 相同,但 PPO 用了一些相对简单的方法来求解。具体来说,PPO 有两种形式,一是 PPO-惩罚(PPO-Penalty),二是 PPO-截断(PPO-Clip),我们接下来对这两种形式进行介绍。

PPO特点

  • 是一种on-policy的算法
  • 可以用于连续或离散的动作空间
  • Open AI Spinning Up 中的PPO可以达到并行运行的效果

PPO-惩罚

PPO-惩罚(PPO-Penalty)用拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,这就变成了一个无约束的优化问题,在迭代的过程中不断更新 KL 散度前的系数。即: arg ⁡ max ⁡ θ E s ∼ ν E a ∼ π θ k ( ⋅ ∣ s ) [ π θ ( a ∣ s ) π θ k ( a ∣ s ) A π θ k ( s , a ) − β D K L [ π θ k ( ⋅ ∣ s ) , π θ ( ⋅ ∣ s ) ] ] \arg\max_{\theta}\mathbb{E}_{s\sim\nu}\mathbb{E}_{a\sim\pi_{\theta_k}(\cdot|s)}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)}A^{\pi_{\theta_k}}(s,a)-\beta D_{KL}[\pi_{\theta_k}(\cdot|s),\pi_\theta(\cdot|s)]\right] argθmaxEsνEaπθk(s)[πθk(as)πθ(as)Aπθk(s,a)βDKL[πθk(s),πθ(s)]]

d k = D K L ν π θ k ( π θ k , π θ ) d_k=D_{KL}^{\nu^{\pi_{\theta_k}}}\left(\pi_{\theta_k},\pi_{\theta}\right) dk=DKLνπθk(πθk,πθ),则 β \beta β的更新规则如下:

  • 如果 d k < δ / 1.5 d_k<\delta/1.5 dk<δ/1.5,那么 β k + 1 = β k / 2 \beta_{k+1}=\beta_k/2 βk+1=βk/2
  • 如果 d k > δ × 1.5 d_k>\delta\times1.5 dk>δ×1.5,那么 β k + 1 = β k × 2 \beta_{k+1}=\beta_k\times2 βk+1=βk×2
  • 否则 β k + 1 = β k \beta_{k+1}=\beta_k βk+1=βk

其中, δ \delta δ是事先设定的一个超参数,用于限制学习策略和之前一轮策略的差距。另外,这里1.5和2是经验参数,算法效能和它们并不是很敏感。

PPO-截断

PPO 的另一种形式 PPO-截断(PPO-Clip)更加直接,它在目标函数中进行限制,以保证新的参数和旧的参数的差距不会太大。TRPO的优化目标如下: L C P I ( θ ) = E ^ t [ π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) A ^ t ] = E ^ t [ r t ( θ ) A ^ t ] . \begin{aligned}L^{CPI}(\theta)&=\hat{\mathbb{E}}_t\bigg[\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t\mid s_t)}\hat{A}_t\bigg]=\hat{\mathbb{E}}_t\bigg[r_t(\theta)\hat{A}_t\bigg].\end{aligned} LCPI(θ)=E^t[πθold(atst)πθ(atst)A^t]=E^t[rt(θ)A^t]. L C L I P ( θ ) = E ^ t [ min ⁡ ( r t ( θ ) A ^ t , clip ⁡ ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] \begin{aligned}L^{CLIP}(\theta)&=\hat{\mathbb{E}}_t\Big[\min(r_t(\theta)\hat{A}_t,\operatorname{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\hat{A}_t)\Big]\end{aligned} LCLIP(θ)=E^t[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]

其中, C P I CPI CPI代表了保守的策略迭代(conservative policy iteration)。新旧策略的采样比例为 r t ( θ )   =   π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)~=~\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t\mid s_t)} rt(θ) = πθold(atst)πθ(atst) clip ⁡ ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \operatorname{clip}(r_t(\theta),1-\epsilon,1+\epsilon) clip(rt(θ),1ϵ,1+ϵ)表示对比例 r t ( θ ) r_t(\theta) rt(θ)截断到 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon,1+\epsilon] [1ϵ,1+ϵ]之内, ϵ \epsilon ϵ为超参数,表示进行截断(clip)的范围。最后 L C L I P ( θ ) L^{CLIP}(\theta) LCLIP(θ)对未被截断和截断部分取最小值,因此,可以剔除比例方差带来的不良影响。当 r t ( θ ) = 1 r_t(\theta)=1 rt(θ)=1时, L C L I P ( θ ) = L C P I ( θ ) L^{CLIP}(\theta)=L^{CPI}(\theta) LCLIP(θ)=LCPI(θ)。如下图所示,若 A > 0 A>0 A>0,说明这个动作的价值高于平均,因此会提升其比例,但不会超过 1 + ϵ 1+\epsilon 1+ϵ;若 A < 0 A<0 A<0,说明这个动作的价值低于平均,因此会降低其比例,但不会低于 1 − ϵ 1-\epsilon 1ϵ
【强化学习】16 ——PPO(Proximal Policy Optimization)_第1张图片

下图是一个更为直观的表示:
【强化学习】16 ——PPO(Proximal Policy Optimization)_第2张图片

优势函数估计

大多数计算方差减小的优势函数的估计方法都会利用学习到的状态值函数 V ( s ) V(s) V(s),如广义优势估计(generalized advantage estimation)或有限视野估计(finite-horizon estimators)。如果使用在策略和值函数之间共享参数的神经网络架构,我们必须使用结合策略替代和值函数误差项的损失函数,通过在目标中添加熵奖励( entropy bonus)来确保充分的探索可以来有效解决这一问题。综上,在每次迭代中所使用的目标函数为: L t C L I P + V F + S ( θ ) = E ^ t [ L t C L I P ( θ ) − c 1 L t V F ( θ ) + c 2 S [ π θ ] ( s t ) ] , L_t^{CLIP+VF+S}(\theta)=\hat{\mathbb{E}}_t\big[L_t^{CLIP}(\theta)-c_1L_t^{VF}(\theta)+c_2S[\pi_\theta](s_t)\big], LtCLIP+VF+S(θ)=E^t[LtCLIP(θ)c1LtVF(θ)+c2S[πθ](st)],

其中, c 1 , c 2 c_1,c_2 c1,c2是系数, S S S为熵奖励, L t V F L_t^{VF} LtVF是平方差 ( V θ ( s t ) − V t t a r g ) 2 \left(V_\theta(s_t)-V_t^\mathrm{targ}\right)^2 (Vθ(st)Vttarg)2

对于优势函数估计,PPO采用多( T T T)步时序差分 A t ^ = − V ( s t ) + r t + γ r t + 1 + ⋯ + γ T − t + 1 r T − 1 + γ T − t V ( s T ) \hat{A_t}=-V(s_t)+r_t+\gamma r_{t+1}+\cdots+\gamma^{T-t+1}r_{T-1}+\gamma^{T-t}V(s_T) At^=V(st)+rt+γrt+1++γTt+1rT1+γTtV(sT)PPO采用的方法是GAE方法的截断版本(当 λ = 1 \lambda=1 λ=1时,就和上式相等)。
A ^ t = δ t + ( γ λ ) δ t + 1 + ⋯ + ⋯ + ( γ λ ) T − t + 1 δ T − 1 , w h e r e δ t = r t + γ V ( s t + 1 ) − V ( s t ) \begin{aligned}\hat A_t&=\delta_t+(\gamma\lambda)\delta_{t+1}+\cdots+\cdots+(\gamma\lambda)^{T-t+1}\delta_{T-1},\\\mathrm{where}\quad\delta_t&=r_t+\gamma V(s_{t+1})-V(s_t)\end{aligned} A^twhereδt=δt+(γλ)δt+1+++(γλ)Tt+1δT1,=rt+γV(st+1)V(st)

算法伪代码

【强化学习】16 ——PPO(Proximal Policy Optimization)_第3张图片

  • 在每次迭代中,并行个actor收集步经验数据
  • 计算每步的 A t ^ \hat{A_t} At^ L C L I P L^{CLIP} LCLIP构成mini-batch(优化方法可用SGD或Adam)
  • 更新参数,并更新 θ o l d ← θ \theta_{old}\leftarrow\theta θoldθ

【强化学习】16 ——PPO(Proximal Policy Optimization)_第4张图片

PPO 代码实践

大量实验表明,PPO-截断总是比 PPO-惩罚表现得更好。因此下面我们专注于 PPO-截断的代码实现。

import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import util

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

# 输入是某个状态,输出则是状态的价值。
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class PPO:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,
                lambda_, clip_param, train_epochs, device, numOfEpisodes, env):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.gamma = gamma
        self.device = device
        self.env = env
        self.numOfEpisodes = numOfEpisodes
        self.lambda_ = lambda_
        # PPO中截断范围的参数
        self.clip_param = clip_param
        # 一条序列的数据用来训练轮数
        self.train_epochs = train_epochs

    def take_action(self, state):
        states = torch.FloatTensor(np.array([state])).to(self.device)
        probs = self.actor(states)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def cal_advantage(self, gamma, lambda_, td_delta):
        td_delta = td_delta.detach().numpy()
        advantages = []
        advantage = 0.0
        for delta in reversed(td_delta):
            advantage = gamma * lambda_ * advantage + delta
            advantages.append(advantage)
        advantages.reverse()
        return torch.FloatTensor(np.array(advantages))

    def update(self, transition_dict):
        states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
        terminateds = torch.tensor(transition_dict['terminateds'], dtype=torch.float).view(-1, 1).to(self.device)
        truncateds = torch.tensor(transition_dict['truncateds'], dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - terminateds + truncateds)
        td_delta = td_target - self.critic(states)
        advantage = self.cal_advantage(self.gamma, self.lambda_, td_delta.cpu()).to(self.device)
        old_log_probs  = torch.log(self.actor(states).gather(1, actions)).detach()
        for _ in range(self.train_epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            L_CPI = ratio * advantage
            L_CLIP = torch.min(L_CPI, torch.clamp(ratio, 1 - self.clip_param, 1 + self.clip_param) * advantage)
            actor_loss = torch.mean(-L_CLIP)
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

    def PPOrun(self):
        returnList = []
        for i in range(10):
            with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:
                for episode in range(int(self.numOfEpisodes / 10)):
                    # initialize state
                    state, info = self.env.reset()
                    terminated = False
                    truncated = False
                    episodeReward = 0
                    transition_dict = {
                        'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'terminateds': [], 'truncateds': []}
                    # Loop for each step of episode:
                    while 1:
                        action = self.take_action(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        transition_dict['states'].append(state)
                        transition_dict['actions'].append(action)
                        transition_dict['next_states'].append(next_state)
                        transition_dict['rewards'].append(reward)
                        transition_dict['terminateds'].append(terminated)
                        transition_dict['truncateds'].append(truncated)
                        state = next_state
                        episodeReward += reward
                        if terminated or truncated:
                            break
                    self.update(transition_dict)
                    returnList.append(episodeReward)
                    if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报
                        pbar.set_postfix({
                            'episode':
                                '%d' % (self.numOfEpisodes / 10 * i + episode + 1),
                            'return':
                                '%.3f' % np.mean(returnList[-10:])
                        })
                    pbar.update(1)
        return returnList

超参数参考设置:

    agent = PPO(state_dim=env.observation_space.shape[0],
                hidden_dim=256,
                action_dim=2,
                actor_lr=1e-3,
                critic_lr=1e-2,
                gamma=0.99,
                lambda_=0.95,
                clip_param=0.2,
                train_epochs=8,
                device=device,
                numOfEpisodes=1000,
                env=env)

结果:
【强化学习】16 ——PPO(Proximal Policy Optimization)_第5张图片
可见,PPO算法收敛速度快,表现十分优秀。

参考

[1] 伯禹AI
[2] https://www.davidsilver.uk/teaching/
[3] 动手学强化学习
[4] Reinforcement Learning
[5] SCHULMAN J, FILIP W, DHARIWAL P, et al. Proximal policy optimization algorithms [J]. Machine Learning, 2017.

你可能感兴趣的:(强化学习,算法,机器学习,人工智能,强化学习)