Let $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ denote the temporal-difference (TD) error, where $V$ is an already-learned state-value function. Following the idea of multi-step temporal difference, we have:

$$
\begin{array}{ll}
A_t^{(1)}=\delta_t & =-V(s_t)+r_t+\gamma V(s_{t+1}) \\
A_t^{(2)}=\delta_t+\gamma \delta_{t+1} & =-V(s_t)+r_t+\gamma r_{t+1}+\gamma^2 V(s_{t+2}) \\
A_t^{(3)}=\delta_t+\gamma \delta_{t+1}+\gamma^2 \delta_{t+2} & =-V(s_t)+r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+\gamma^3 V(s_{t+3}) \\
\vdots & \vdots \\
A_t^{(k)}=\sum_{l=0}^{k-1} \gamma^l \delta_{t+l} & =-V(s_t)+r_t+\gamma r_{t+1}+\ldots+\gamma^{k-1} r_{t+k-1}+\gamma^k V(s_{t+k})
\end{array}
$$
To briefly explain the formulas above: by the definition of $\delta$,

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$
$$\delta_{t+1} = r_{t+1} + \gamma V(s_{t+2}) - V(s_{t+1})$$
Combining these two equations gives:

$$
\begin{aligned}
A_t^{(2)} = \delta_t + \gamma \delta_{t+1} &= r_t+\gamma V(s_{t+1})-V(s_t) + \gamma\left(r_{t+1} + \gamma V(s_{t+2}) - V(s_{t+1})\right)\\
&= r_t+\gamma V(s_{t+1})-V(s_t) + \gamma r_{t+1} + \gamma^{2} V(s_{t+2}) - \gamma V(s_{t+1})\\
&= -V(s_t) + r_t + \gamma r_{t+1} + \gamma^{2} V(s_{t+2})
\end{aligned}
$$

The remaining $A_t^{(k)}$ can be derived in the same way: the intermediate $V$ terms telescope away, leaving only $-V(s_t)$, the discounted rewards, and the final bootstrap term.
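As an optional sanity check (a minimal sketch added here, not from the original text; the rewards, values, $\gamma$, and $k$ below are made-up numbers), the telescoping identity $A_t^{(k)}=\sum_{l=0}^{k-1}\gamma^l\delta_{t+l}$ can be verified numerically:

import numpy as np

gamma = 0.99
# Hypothetical rewards r_t..r_{t+k-1} and values V(s_t)..V(s_{t+k}) for k = 4
rewards = np.array([1.0, 0.5, -0.2, 2.0])
values = np.array([0.3, 0.8, 0.1, 0.6, 0.4])  # one more entry than rewards
k = len(rewards)

# TD errors: delta_{t+l} = r_{t+l} + gamma * V(s_{t+l+1}) - V(s_{t+l})
deltas = rewards + gamma * values[1:] - values[:-1]

# Left-hand side: discounted sum of TD errors
lhs = sum((gamma ** l) * deltas[l] for l in range(k))

# Right-hand side: -V(s_t) + discounted rewards + gamma^k * V(s_{t+k})
rhs = -values[0] + sum((gamma ** l) * rewards[l] for l in range(k)) + (gamma ** k) * values[-1]

print(np.isclose(lhs, rhs))  # True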
The idea behind GAE is to take an exponentially weighted average of these advantage estimates with different step counts. Let us first introduce exponentially weighted averaging with a simple example.
Suppose we want the exponentially weighted average of 100 days of temperatures: $22, 24, 25, 27, 33, 24, \ldots, 25$. The exponentially weighted average is computed as

$$v_t = \lambda v_{t-1} + (1-\lambda)\theta_t$$

where $v_t$ is the averaged temperature up to day $t$, $\theta_t$ is the temperature on day $t$, and $\lambda$ is a tunable hyperparameter.
With $\lambda=0.9$, the exponentially weighted averages are:

$$
\begin{aligned}
v_{100} &= 0.9 \cdot v_{99} + 0.1 \cdot \theta_{100}\\
v_{99} &= 0.9 \cdot v_{98} + 0.1 \cdot \theta_{99}\\
v_{98} &= 0.9 \cdot v_{97} + 0.1 \cdot \theta_{98}\\
&\;\;\vdots\\
v_{1} &= 0.9 \cdot v_{0} + 0.1 \cdot \theta_{1}\\
v_{0} &= 0.1 \cdot \theta_{0}
\end{aligned}
$$
Unrolling this recursion gives:

$$
\begin{aligned}
v_{100} &= 0.1 \cdot \theta_{100} + 0.1 \cdot 0.9 \cdot \theta_{99} + 0.1 \cdot 0.9^2 \cdot \theta_{98} + \ldots + 0.1 \cdot 0.9^{100} \cdot \theta_0 \\
&= 0.1 \cdot \left(\theta_{100} + 0.9 \cdot \theta_{99} + 0.9^2 \cdot \theta_{98} + \ldots + 0.9^{100} \cdot \theta_0\right) \\
&= (1-\lambda) \cdot \left(\theta_{100} + \lambda \cdot \theta_{99} + \lambda^2 \cdot \theta_{98} + \ldots + \lambda^{100} \cdot \theta_0\right)
\end{aligned}
$$

That is, the latest observation gets the largest weight, and older observations are discounted geometrically by $\lambda$.
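A minimal sketch (not from the original text; the temperature list below is hypothetical) showing that the recursive form and the unrolled weighted sum agree:

import numpy as np

lmbda = 0.9
temps = np.array([22.0, 24.0, 25.0, 27.0, 33.0, 24.0, 25.0])  # hypothetical theta_0..theta_6

# Recursive form: v_t = lmbda * v_{t-1} + (1 - lmbda) * theta_t, with v_0 = (1 - lmbda) * theta_0
v = (1 - lmbda) * temps[0]
for theta in temps[1:]:
    v = lmbda * v + (1 - lmbda) * theta

# Unrolled form: (1 - lmbda) * sum_k lmbda^k * theta_{T-k}
T = len(temps) - 1
v_unrolled = (1 - lmbda) * sum((lmbda ** k) * temps[T - k] for k in range(T + 1))

print(np.isclose(v, v_unrolled))  # True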
Applying the same weighting scheme to GAE, we obtain:

$$
\begin{aligned}
A_t^{GAE} & =(1-\lambda)\left(A_t^{(1)}+\lambda A_t^{(2)}+\lambda^2 A_t^{(3)}+\cdots\right) \\
& =(1-\lambda)\left(\delta_t+\lambda\left(\delta_t+\gamma \delta_{t+1}\right)+\lambda^2\left(\delta_t+\gamma \delta_{t+1}+\gamma^2 \delta_{t+2}\right)+\cdots\right) \\
& =(1-\lambda)\left(\delta_t\left(1+\lambda+\lambda^2+\cdots\right)+\gamma \delta_{t+1}\left(\lambda+\lambda^2+\lambda^3+\cdots\right)+\gamma^2 \delta_{t+2}\left(\lambda^2+\lambda^3+\lambda^4+\cdots\right)+\cdots\right) \\
& =(1-\lambda)\left(\delta_t \frac{1}{1-\lambda}+\gamma \delta_{t+1} \frac{\lambda}{1-\lambda}+\gamma^2 \delta_{t+2} \frac{\lambda^2}{1-\lambda}+\cdots\right) \\
& =\sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_{t+l}
\end{aligned}
$$
Here $\lambda \in [0,1]$ is the hyperparameter introduced by GAE. When $\lambda=0$,

$$A_t^{GAE}=A_t^{(1)}=\delta_t=r_t+\gamma V(s_{t+1})-V(s_t)$$

i.e., only the one-step TD advantage is used. As $\lambda$ approaches 1, GAE averages over advantage estimates with more and more steps.
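For the $\lambda=1$ limit specifically (a short worked expansion added here for illustration; it is not spelled out in the original derivation but follows directly from the formula above, using the same telescoping as before):

$$A_t^{GAE}\Big|_{\lambda=1}=\sum_{l=0}^{\infty}\gamma^l \delta_{t+l}=-V(s_t)+\sum_{l=0}^{\infty}\gamma^l r_{t+l}$$

which is the full Monte Carlo return minus the value baseline: low bias but high variance. $\lambda$ therefore trades off bias against variance in the advantage estimate.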
Below is an implementation of GAE. Given the discount factor $\gamma$, the GAE hyperparameter $\lambda$, and the sequence of TD errors $\delta_t$ over an episode, it computes the advantage estimates according to the formula above:
# PyTorch version
import torch

def compute_advantage(gamma, lmbda, td_delta):
    # td_delta: 1-D tensor of TD errors delta_t for one episode
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    # Iterate backwards: A_t = delta_t + gamma * lambda * A_{t+1}
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)
# TensorFlow version
import tensorflow as tf

def compute_advantage(gamma, lmbda, td_delta):
    # td_delta: 1-D tensor of TD errors delta_t for one episode
    td_delta = tf.stop_gradient(td_delta)
    advantage_list = []
    advantage = 0.0
    # Iterate backwards: A_t = delta_t + gamma * lambda * A_{t+1}
    for delta in tf.reverse(td_delta, axis=[0]):
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return tf.convert_to_tensor(advantage_list, dtype=tf.float32)
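As a usage sketch (reusing `torch` and the PyTorch `compute_advantage` defined above; the reward and value tensors below are hypothetical placeholders, not from the original text), the function would typically be called on TD errors computed from a critic:

# Hypothetical episode data: rewards, critic values V(s_t), and bootstrap values V(s_{t+1})
rewards = torch.tensor([1.0, 0.5, -0.2, 2.0])
values = torch.tensor([0.3, 0.8, 0.1, 0.6])       # V(s_t) for each step
next_values = torch.tensor([0.8, 0.1, 0.6, 0.4])  # V(s_{t+1}) for each step

gamma, lmbda = 0.99, 0.95
td_delta = rewards + gamma * next_values - values  # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
advantages = compute_advantage(gamma, lmbda, td_delta)
print(advantages)  # one advantage estimate per time step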
Reference: Hands-on Reinforcement Learning (动手学强化学习)