Policy Gradient策略梯度(PG),是一种基于策略的强化学习算法,不少帖子会讲到从基于值的算法(Q-learning/DQN/Saras)到基于策略的算法难以理解,我的理解是两者是完全两套思路,在学习一种的时候先不要考虑另一种,更容易接受算法基本思想,了解了算法原理推导过程之后再比较两者不同之处那么更容易理解了
❀策略执行
Policy Gradient算法是学习策略概率密度函数 π ( a ∣ s ) \pi(a|s) π(a∣s),它表示当前状态 s s s下执行动作 a a a的概率,策略执行的时候根据 π ( a ∣ s ) \pi(a|s) π(a∣s)抽样一个动作 a a a,这里容易混淆的地方是,抽样得到的动作 a a a不一定是概率最大的,某一次抽样的结果是随机的,随机性服从的是 π ( a ∣ s ) \pi(a|s) π(a∣s)策略概率密度
这里回顾基于值的算法(Q-learning/DQN/Saras)学习动作价值函数 Q ( s , a ) Q(s,a) Q(s,a),策略执行的时候一般采用epsilon-greedy策略,也就是选择最大 Q Q Q值对应的动作,是不是理解了Q-learning和Policy Gradient的一个区别
❀策略学习
Policy Gradient算法的思想是先将策略表示成一个和奖励有关的连续函数,然后用连续函数的优化方法去寻找最优的策略,优化目标是最大化连续函数,最常用的是优化方法是梯度上升法(与最小化loss的梯度下降相对)。
巧妙之处在于,它利用reward奖励直接对选择动作的可能性进行增强和减弱,好的动作会被增加下一次被选中的概率,不好的动作会被减弱下次被选中的概率。
这是Policy Gradient和Q-learning算法的又一次区别,Q-learning算法每步都可以更新Q值,更新是基于梯度下降的, 它的loss是TD Target,也就是当前状态下执行动作的Q值(完全是估计的),与假设再执行一步之后Q值(一部分是真实观测)的差值。
Policy Gradient的目标函数有以下三种情况
(1)最简单的优化目标就是初始状态收获的期望
(2)但是有的问题是没有明确的初始状态的,那么我们的优化目标可以定义平均价值
(3)或者定义为每一时间步的平均奖励
已经存在Q-learning/DQN/Saras这样基于值的好用的算法,为什么需要Policy Gradient呢?基于值的算法不能处理连续动作,对于高维离散动作,Q-learning更新每个值也需要大量的时间,是不可行的
Policy Gradient算法的优势和劣势总结如下
优势:
劣势:
Policy Gradient算法原理伪代码如下,采用的目标函数是上面讲到的第(1)种形式——最简单的优化目标就是初始状态收获的期望
接下来推导伪代码中的 log π θ ( s t , a t ) \log\pi_\theta(s_t, a_t) logπθ(st,at)是如何得来的
首先,将状态价值函数 V V V写成动作价值函数 Q Q Q关于动作 A A A的期望,再对状态价值函数对S做积分得到目标函数 J ( θ ) = E s [ V ( S ; θ ) ] J(\boldsymbol{\theta})=\mathbb{E_s}[V(S;\boldsymbol{\theta})] J(θ)=Es[V(S;θ)],可以理解成上述提到的第(2)种形式,此时目标函数只和策略网络参数 θ \boldsymbol{\theta} θ有关
训练策略网络采用的是梯度上升算法,下面推导出状态价值函数的导数 ∂ V ( s ; θ ) ∂ θ \frac{\partial V(s;\boldsymbol{\theta})}{\partial \boldsymbol{\theta}} ∂θ∂V(s;θ)是策略梯度
下面推导策略梯度计算方法,Form1适用于离散动作,Form2适用于连续动作
这样得到了策略梯度的两种形式,如下面所示
最后,Policy Gradient算法流程,它和图1流程的区别在于,图1采用REINFORCE算法估计动作价值Q
这里分享条理非常清晰的两份讲解
https://zhuanlan.zhihu.com/p/165439436
https://blog.csdn.net/qq_30615903/article/details/80747380
另外一份博客讲解着重列举示意示例,有助于形象化理解
https://zhuanlan.zhihu.com/p/110881517
大神老师的讲解
下一份博客分享将会介绍Policy Gradient算法操作流程中,另一种估计动作价值函数 Q Q Q的方法,也就是用Actor网络估计,动作网络和策略网络同时使用的方法称之为Actor-Critic(AC)算法