1. Differences between Q-learning and SARSA
Q-learning is off-policy: it can learn from previously collected experience, which may even be someone else's, so the learner and the decision-maker need not be the same. Its update target uses the greedy action while actions are actually chosen ε-greedily, so the behavior policy and the evaluation (target) policy are not the same policy.
SARSA is on-policy: it learns while it acts, so the learner is also the decision-maker. Both its behavior policy and its evaluation policy are ε-greedy. Compared with Q-learning it is more conservative.
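The difference shows up directly in the standard update rules; in the notation below, α is the learning rate (learning_rate in the code) and γ is the discount factor (reward_decay):

Q-learning: $Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]$
SARSA: $Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma\, Q(s',a') - Q(s,a) \right]$

In SARSA, the a' in the target is the action actually chosen ε-greedily in s', not the greedy maximum; that is exactly the extra argument passed to learn in the code below.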
2. SARSA algorithm code implementation - test
from maze_env import Maze
from RL_brain import SarsaTable


def update():
    for episode in range(100):
        # initial observation
        observation = env.reset()
        # choose an action based on the current observation
        action = RL.choose_action(str(observation))

        while True:
            # refresh the environment
            env.render()
            # take the action; get the next observation, the reward, and whether the episode is done
            observation_, reward, done = env.step(action)
            # choose the next action from the next observation; it becomes the action of the next step
            action_ = RL.choose_action(str(observation_))
            # learn from the transition (s, a, r, s_, a_); unlike Q-learning, the next action is also passed in
            RL.learn(str(observation), action, reward, str(observation_), action_)
            # the next state and action become the current ones
            observation = observation_
            action = action_
            # break the loop when the episode ends
            if done:
                break

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()
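For comparison, a Q-learning training loop differs only in where the next action is chosen: it is picked freshly at the top of every step and is not passed to learn. A minimal sketch, assuming env is the same Maze and RL is a QLearningTable (defined in the next section) rather than a SarsaTable; the function name update_qlearning is only for illustration:

# Q-learning version of the loop, for comparison with the SARSA loop above.
def update_qlearning():
    for episode in range(100):
        observation = env.reset()
        while True:
            env.render()
            # the action is chosen inside the loop from the current observation
            action = RL.choose_action(str(observation))
            # take the action and observe the result
            observation_, reward, done = env.step(action)
            # learn from (s, a, r, s_); no next action is needed
            RL.learn(str(observation), action, reward, str(observation_))
            observation = observation_
            if done:
                break
    print('game over')
    env.destroy()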
3. SARSA algorithm code implementation - brain
import numpy as np
import pandas as pd


class RL(object):
    # Parent class that QLearningTable and SarsaTable inherit from. It has three parts:
    # 1. __init__ stores the hyperparameters and creates the Q table;
    # 2. check_state_exist adds a state to the Q table if it is not there yet;
    # 3. choose_action picks an action (90% greedy, 10% random by default).
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list of actions
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append a new row of zeros for the unseen state
            new_row = pd.Series(
                [0.0] * len(self.actions),
                index=self.q_table.columns,
                name=state,
            )
            # DataFrame.append was removed in pandas 2.0, so concatenate the new row instead
            self.q_table = pd.concat([self.q_table, new_row.to_frame().T])

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # choose the best action (break ties between equally valued actions at random)
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        # this part differs between Q-learning and SARSA, so it is implemented in the subclasses
        pass
# off-policy
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # next state is not terminal: bootstrap from the maximum Q value over the next state
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
# on-policy
class SarsaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # pass the parameters through the parent class
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        # compared with Q-learning, the next action a_ is also part of what is learned
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # next state is not terminal: unlike Q-learning, this is not the maximum value over
            # the next state, but the Q value of the action that will actually be taken next
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
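To see the difference between the two learn methods on concrete numbers, here is a small hand-run check using the classes above; the state names '0' and '1', the 4-action space, and the Q values are made up for illustration. With the default lr = 0.01 and gamma = 0.9, if the next-state row is [1.0, 0.2, 0.0, 0.0], Q-learning bootstraps from the maximum 1.0, while SARSA bootstraps from the value of the next action actually chosen.

# Toy comparison of the two targets; states, actions and values are illustrative only.
ql = QLearningTable(actions=list(range(4)))
sa = SarsaTable(actions=list(range(4)))

# seed both tables with the same Q values for the next state '1'
for agent in (ql, sa):
    agent.check_state_exist('0')
    agent.check_state_exist('1')
    agent.q_table.loc['1', :] = [1.0, 0.2, 0.0, 0.0]

# Q-learning target: r + gamma * max_a Q('1', a) = 0 + 0.9 * 1.0 = 0.9
ql.learn('0', 0, 0, '1')
# SARSA target: r + gamma * Q('1', a_) with the chosen next action a_ = 1, i.e. 0 + 0.9 * 0.2 = 0.18
sa.learn('0', 0, 0, '1', 1)

print(ql.q_table.loc['0', 0])  # 0.01 * 0.9  = 0.009
print(sa.q_table.loc['0', 0])  # 0.01 * 0.18 = 0.0018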