【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解

【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解

Sarsa算法的决策部分和Q-learning相同,所以下面的内容依然会基于上片Qlearning的公式推导。由于与Qlearning极大程度相似所以不会花太大的篇幅去说明。
本文图片素材引自莫烦老师的教学视频,笔者也是从新手看着莫烦老师的视频一步一步学习的。文章旨在记录分享和自己的理解。
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/3-1-A-sarsa/

1、算法思想

Sarsa算法的的决策部分与Qlearning相同,都是通过Q表的形式进行决策,在 Q 表中挑选值较大的动作值施加在环境中来换取奖惩,也就是根据计算出来的Q值来作为选取动作的依据,两者不同的是行为更新准则是有差异的。Sarsa不会去选取他估计出来的最大Q估计值,而是直接选取估计出来的Q值。

2、行为更新

举个例子,我们会经历正在写作业的状态 s1, 然后再挑选一个带来最大潜在奖励的动作 a2, 这样我们就到达了 继续写作业状态 s2, 而在这一步, 如果你用的是 Q learning, 你会观看一下在 s2 上选取哪一个动作会带来最大的奖励, 但是在真正要做决定时, 却不一定会选取到那个带来最大奖励的动作, Q-learning 在这一步只是估计了一下接下来的动作值. 而 Sarsa 是实践派, 他说到做到, 在 s2 这一步估算的动作也是接下来要做的动作. 所以 Q(s1, a2) 现实的计算值, 我们也会稍稍改动, 去掉maxQ, 取而代之的是在 s2 上我们实实在在***选取的 a2 的 Q 值***. 最后像 Q learning 一样, 求出现实和估计的差距 并更新 Q 表里的 Q(s1, a2)

我们可以通过下面的公式来观察出他们之间的区别

Q-Learning的更新公式

Q k + 1 ∗ ( s ) ← ∑ s ′ P ( s ′ ∣ s , a ) ( R ( s , a , s ′ ) + γ m a x a ′ Q ∗ ( s ′ , a ′ ) ) Q^∗_{k+1}(s)←∑_{s′}P(s′|s,a)(R(s,a,s′)+γmax_{a′}Q^∗(s′,a′)) Qk+1(s)sP(ss,a)(R(s,a,s)+γmaxaQ(s,a))

Sarsa的更新公式

Q k + 1 ∗ ( s ) ← ∑ s ′ P ( s ′ ∣ s , a ) ( R ( s , a , s ′ ) + γ Q ∗ ( s ′ , a ′ ) ) Q^∗_{k+1}(s)←∑_{s′}P(s′|s,a)(R(s,a,s′)+γQ^∗(s′,a′)) Qk+1(s)sP(ss,a)(R(s,a,s)+γQ(s,a))

【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解_第1张图片

Sarsa 是说到做到型, 所以我们也叫他 on-policy, 在线学习, 学着自己在做的事情.Sarsa相当保守,他会选择离危险远远的,拿到宝藏是次要的, 保住自己的小命才是王道. 这就是使用 Sarsa 方法的不同之处.
Q learning 是说到但并不一定做到,所以它也叫作 Off-policy,离线学习.而因为有了 maxQ,Q-learning 也是一个特别勇敢的算法.永远都会选择最近的一条通往成功的道路, 不管这条路会有多危险.

3、Sarsa-lamda: Sarsa 的一种提速方法

Sarsa 是一种单步更新法, 也就是 Sarsa(0), 因为他等走完这一步以后直接更新行为准则. 如果延续这种想法, 走完这步, 再走一步, 然后再更新, 我们可以叫他 Sarsa(1). 同理, 如果等待回合完毕我们一次性再更新呢, 比如这回合我们走了 n 步, 那我们就叫 Sarsa(n). 为了统一这样的流程, 我们就有了一个 lambda 值来代替我们想要选择的步数。 这也就是 Sarsa(λ)的由来,Sarsa 和 Qlearning 都是每次获取到奖励reward后只更新获取到 reward 的前一步,那么Sarsa(λ)就是更新获取到 reward 的前 λ 步. λ 在 [0, 1] 之间取值,
当 lambda = 0, Sarsa-lambda 就是 Sarsa单步更新, 只更新获取到 reward 前经历的最后一步。如果 lambda = 1, Sarsa-lambda就变成了回合更新,更新的是获取到 reward 前所有经历的步,对所有步更新的力度都是一样. 当 lambda 在 0 和 1 之间, 取值越大, 获得奖励大的步更新力度越大. 这样我们就不用受限于单步更新的每次只能更新最近的一步, 我们可以更有效率的更新所有相关步了。【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解_第2张图片
Sarsa(λ)算法伪代码
【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解_第3张图片

4、算法实现

# off-policy
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r  # 下一个状态就结束了
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update


# on-policy
class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    # 表示继承关系

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # q_target 基于选好的 a_ 而不是 Q(s_) 的最大值
        else:
            q_target = r  # 如果 s_ 是终止符
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新 q_table

# backward eligibility traces
class SarsaLambdaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        # backward view, eligibility trace.
        self.lambda_ = trace_decay
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table
            to_be_append = pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            self.q_table = self.q_table.append(to_be_append)

            # also update eligibility trace
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

    def learn(self, s, a, r, s_, a_):
        # 这部分和 Sarsa 一样
        self.check_state_exist(s_)
        q_predict = self.q_table.ix[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.ix[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict

        # 这里开始不同:
        # 对于经历过的 state-action, 我们让他+1, 证明他是得到 reward 路途中不可或缺的一环
        self.eligibility_trace.ix[s, a] += 1

        # Q table 更新
        self.q_table += self.lr * error * self.eligibility_trace

        # 随着时间衰减 eligibility trace 的值, 离获取 reward 越远的步, 他的"不可或缺性"越小
        self.eligibility_trace *= self.gamma * self.lambda_

你可能感兴趣的:(强化学习)