Sarsa算法的决策部分和Q-learning相同,所以下面的内容依然会基于上片Qlearning的公式推导。由于与Qlearning极大程度相似所以不会花太大的篇幅去说明。
本文图片素材引自莫烦老师的教学视频,笔者也是从新手看着莫烦老师的视频一步一步学习的。文章旨在记录分享和自己的理解。
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/3-1-A-sarsa/
Sarsa算法的的决策部分与Qlearning相同,都是通过Q表的形式进行决策,在 Q 表中挑选值较大的动作值施加在环境中来换取奖惩,也就是根据计算出来的Q值来作为选取动作的依据,两者不同的是行为更新准则是有差异的。Sarsa不会去选取他估计出来的最大Q估计值,而是直接选取估计出来的Q值。
举个例子,我们会经历正在写作业的状态 s1, 然后再挑选一个带来最大潜在奖励的动作 a2, 这样我们就到达了 继续写作业状态 s2, 而在这一步, 如果你用的是 Q learning, 你会观看一下在 s2 上选取哪一个动作会带来最大的奖励, 但是在真正要做决定时, 却不一定会选取到那个带来最大奖励的动作, Q-learning 在这一步只是估计了一下接下来的动作值. 而 Sarsa 是实践派, 他说到做到, 在 s2 这一步估算的动作也是接下来要做的动作. 所以 Q(s1, a2) 现实的计算值, 我们也会稍稍改动, 去掉maxQ, 取而代之的是在 s2 上我们实实在在***选取的 a2 的 Q 值***. 最后像 Q learning 一样, 求出现实和估计的差距 并更新 Q 表里的 Q(s1, a2)
我们可以通过下面的公式来观察出他们之间的区别
Sarsa 是说到做到型, 所以我们也叫他 on-policy, 在线学习, 学着自己在做的事情.Sarsa相当保守,他会选择离危险远远的,拿到宝藏是次要的, 保住自己的小命才是王道. 这就是使用 Sarsa 方法的不同之处.
Q learning 是说到但并不一定做到,所以它也叫作 Off-policy,离线学习.而因为有了 maxQ,Q-learning 也是一个特别勇敢的算法.永远都会选择最近的一条通往成功的道路, 不管这条路会有多危险.
Sarsa 是一种单步更新法, 也就是 Sarsa(0), 因为他等走完这一步以后直接更新行为准则. 如果延续这种想法, 走完这步, 再走一步, 然后再更新, 我们可以叫他 Sarsa(1). 同理, 如果等待回合完毕我们一次性再更新呢, 比如这回合我们走了 n 步, 那我们就叫 Sarsa(n). 为了统一这样的流程, 我们就有了一个 lambda 值来代替我们想要选择的步数。 这也就是 Sarsa(λ)的由来,Sarsa 和 Qlearning 都是每次获取到奖励reward后只更新获取到 reward 的前一步,那么Sarsa(λ)就是更新获取到 reward 的前 λ 步. λ 在 [0, 1] 之间取值,
当 lambda = 0, Sarsa-lambda 就是 Sarsa单步更新, 只更新获取到 reward 前经历的最后一步。如果 lambda = 1, Sarsa-lambda就变成了回合更新,更新的是获取到 reward 前所有经历的步,对所有步更新的力度都是一样. 当 lambda 在 0 和 1 之间, 取值越大, 获得奖励大的步更新力度越大. 这样我们就不用受限于单步更新的每次只能更新最近的一步, 我们可以更有效率的更新所有相关步了。
Sarsa(λ)算法伪代码
# off-policy
class QLearningTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
def learn(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal
else:
q_target = r # 下一个状态就结束了
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
# on-policy
class SarsaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示继承关系
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # q_target 基于选好的 a_ 而不是 Q(s_) 的最大值
else:
q_target = r # 如果 s_ 是终止符
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # 更新 q_table
# backward eligibility traces
class SarsaLambdaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
# backward view, eligibility trace.
self.lambda_ = trace_decay
self.eligibility_trace = self.q_table.copy()
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
to_be_append = pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
self.q_table = self.q_table.append(to_be_append)
# also update eligibility trace
self.eligibility_trace = self.eligibility_trace.append(to_be_append)
def learn(self, s, a, r, s_, a_):
# 这部分和 Sarsa 一样
self.check_state_exist(s_)
q_predict = self.q_table.ix[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.ix[s_, a_]
else:
q_target = r
error = q_target - q_predict
# 这里开始不同:
# 对于经历过的 state-action, 我们让他+1, 证明他是得到 reward 路途中不可或缺的一环
self.eligibility_trace.ix[s, a] += 1
# Q table 更新
self.q_table += self.lr * error * self.eligibility_trace
# 随着时间衰减 eligibility trace 的值, 离获取 reward 越远的步, 他的"不可或缺性"越小
self.eligibility_trace *= self.gamma * self.lambda_