目录
1. 要点
2. 算法流程
3. 算法代码部分
3.1 迭代更新
3.2 思维决策代码
3.2.1 学习
这次我们用同样的迷宫例子来实现 RL 中另一种和 Qlearning 类似的算法, 叫做 Sarsa (state-action-reward-state-action). 我们从这一个简称可以了解到, Sarsa 的整个循环都将是在一个路径上, 也就是 on-policy, 下一个 state, 和下一个 action 将会变成他真正采取的 action 和 state. 和 Qlearning 的不同之处就在这. Qlearning 的下一个 state_ 和action_ 在算法更新的时候都还是不确定的 (off-policy). 而 Sarsa 的 state, action 在这次算法更新的时候已经确定好了 (on-policy).
Q-learning是off-policy的,就是可以看着别人玩,自己学着别人再玩;Sarsa是on-policy的算法,自身走到哪一步就学习哪一步,所以Sarsa只能从自身的经验学。
整个算法还是一直不断更新 Q table 里的值, 然后再根据新的值来判断要在某个 state 采取怎样的 action. 不过于 Qlearning 不同之处:
state
已经想好了 state
对应的 action
, 而且想好了 下一个 state_
和下一个 action_
(Qlearning 还没有想好下一个 action_
)Q(s,a)
的时候基于的是下一个 Q(s_, a_)
(Qlearning 是基于 maxQ(s_)
)这种不同之处使得 Sarsa 相对于 Qlearning, 更加的胆小. 因为 Qlearning 永远都是想着 maxQ
最大化, 因为这个 maxQ
而变得贪婪, 不考虑其他非 maxQ
的结果. 我们可以理解成 Qlearning 是一种贪婪, 大胆, 勇敢的算法, 对于错误, 死亡并不在乎. 而 Sarsa 是一种保守的算法, 他在乎每一步决策, 对于错误和死亡比较敏感. 这一点我们会在可视化的部分看出他们的不同. 两种算法都有他们的好处, 比如在实际中, 你比较在乎机器的损害, 用一种保守的算法, 在训练时就能减少损坏的次数.
通过演示,可以发现Sarsa算法中的agent非常“怕”那两个坑,以至于会长时间再左上角徘徊,也不敢靠近坑。可以说Sarsa算法中的agent比较“胆小”和谨慎。但其实最后也是可以找到宝藏的,因为还有10%的可能是随机走的。
首先我们先 import 两个模块, maze_env
是我们的环境模块, 已经编写好了, 大家可以直接在这里下载, maze_env
模块我们可以不深入研究, 如果你对编辑环境感兴趣, 可以去看看如何使用 python 自带的简单 GUI 模块 tkinter
来编写虚拟环境. 莫烦也有对应的教程. maze_env
就是用 tkinter
编写的. 而 RL_brain
这个模块是 RL 的大脑部分, 在3.2节中介绍.
from maze_env import Maze
from RL_brain import SarsaTable
下面的代码, 我们可以根据上面的图片中的算法对应起来, 这就是整个 Sarsa 最重要的迭代更新部分啦.
def update():
for episode in range(100):
# 初始化环境
observation = env.reset()
# Sarsa 根据 state 观测选择行为
action = RL.choose_action(str(observation))
while True:
# 刷新环境
env.render()
# 在环境中采取行为, 获得下一个 state_ (obervation_), reward, 和是否终止
observation_, reward, done = env.step(action)
# 根据下一个 state (obervation_) 选取下一个 action_
action_ = RL.choose_action(str(observation_))
# 从 (s, a, r, s, a) 中学习, 更新 Q_tabel 的参数 ==> Sarsa
RL.learn(str(observation), action, reward, str(observation_), action_)
# 将下一个当成下一步的 state (observation) and action
observation = observation_
action = action_
# 终止时跳出循环
if done:
break
# 大循环完毕
print('game over')
env.destroy()
if __name__ == "__main__":
env = Maze()
RL = SarsaTable(actions=list(range(env.n_actions)))
env.after(100, update)
env.mainloop()
接着上节内容, 我们来实现 RL_brain
的 SarsaTable
部分, 这也是 RL 的大脑部分, 负责决策和思考.
和之前定义 Qlearning 中的 QLearningTable
一样, 因为使用 tabular 方式的 Sarsa
和 Qlearning
的相似度极高,
class SarsaTable:
# 初始化 (与之前一样)
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
# 选行为 (与之前一样)
def choose_action(self, observation):
# 学习更新参数 (有改变)
def learn(self, s, a, r, s_):
# 检测 state 是否存在 (与之前一样)
def check_state_exist(self, state):
我们甚至可以定义一个 主class RL
, 然后将 QLearningTable
和 SarsaTable
作为 主class RL
的衍生, 这个主 RL
可以这样定义. 所以我们将之前的__init__,check_state_exist, choose_action, learn全部都放在这个主结构中, 之后根据不同的算法更改对应的内容就好了. 所以还没弄懂这些功能的朋友们, 请回到之前的教程再看一遍.
import numpy as np
import pandas as pd
class RL(object):
def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
... # 和 QLearningTable 中的代码一样
def check_state_exist(self, state):
... # 和 QLearningTable 中的代码一样
def choose_action(self, observation):
... # 和 QLearningTable 中的代码一样
def learn(self, *args):
pass # 每种的都有点不同, 所以用 pass
如果是这样定义父类的 RL
class, 通过继承关系, 那之子类 QLearningTable
class 就能简化成这样:
class QLearningTable(RL): # 继承了父类 RL
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示继承关系
def learn(self, s, a, r, s_): # learn 的方法在每种类型中有不一样, 需重新定义
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
else:
q_target = r
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
有了父类的 RL
, 我们这次的编写就很简单, 只需要编写 SarsaTable
中 learn
这个功能就完成了. 因为其他功能都和父类是一样的. 这就是我们所有的 SarsaTable
于父类 RL
不同之处的代码. 是不是很简单.
class SarsaTable(RL): # 继承 RL class
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示继承关系
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # q_target 基于选好的 a_ 而不是 Q(s_) 的最大值
else:
q_target = r # 如果 s_ 是终止符
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # 更新 q_table
如果想一次性看到全部代码, 请去我的Github