[Reinforcement Learning] Comparing Q-learning and SARSA, with a SARSA Implementation

1. Differences Between Q-learning and SARSA

Q-learning is off-policy: it can learn from previously collected experience, including experience generated by some other policy, so the learner and the decision-maker need not be the same. Its update target uses the greedy policy (the max over next-state Q-values), while actions are actually selected ε-greedily. In other words, the behavior policy and the evaluation (target) policy are two different policies.
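
In update-rule form (this is exactly what the QLearningTable.learn method in section 3 computes):

    Q(s, a) ← Q(s, a) + α [ r + γ max_{a'} Q(s', a') - Q(s, a) ]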

SARSA is on-policy: it learns while it acts, so the learner and the decision-maker are the same. Both its behavior policy and its evaluation policy are the ε-greedy policy, and its update target uses the action it will actually take next. Compared with Q-learning, SARSA is therefore more conservative: on the classic cliff-walking task, for example, it learns the safer path away from the cliff, while Q-learning learns the shorter but riskier path along the edge.
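
The SARSA update has the same form, but the target uses the Q-value of the next action a' that the ε-greedy policy actually takes, rather than the maximum over actions:

    Q(s, a) ← Q(s, a) + α [ r + γ Q(s', a') - Q(s, a) ]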

2. SARSA Implementation: the Training Script (test)

from maze_env import Maze
from RL_brain import SarsaTable


def update():
    for episode in range(100):
        # reset the environment and get the initial observation
        observation = env.reset()

        # SARSA chooses the first action before the episode loop starts
        action = RL.choose_action(str(observation))

        while True:
            # refresh the environment display
            env.render()

            # take the action; get the next observation, the reward, and the done flag
            observation_, reward, done = env.step(action)

            # choose the next action from the next observation
            # (this is the action that will actually be executed next step)
            action_ = RL.choose_action(str(observation_))

            # SARSA learns from the full (s, a, r, s', a') transition,
            # i.e. the next action is part of the update
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # move on: the next state and the already-chosen next action become the current ones
            observation = observation_
            action = action_

            # end the episode when the environment signals termination
            if done:
                break

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))

    env.after(100, update)
    env.mainloop()
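
The script imports maze_env.Maze from the tutorial's environment file, which is not listed in this post. For reference only, here is a minimal stand-in inferred purely from the calls the script makes (reset, step, render, destroy, after, mainloop, the n_actions attribute, and the 'terminal' observation that RL_brain checks for). It is a sketch, not the original tkinter grid-world maze.

import tkinter as tk

class Maze(tk.Tk):
    """A tiny 1-D 'corridor' stand-in for the tutorial's grid-world maze."""

    def __init__(self, size=5):
        super().__init__()
        self.n_actions = 2   # 0: step left, 1: step right
        self.size = size     # the goal sits at the right-hand end
        self.pos = 0

    def reset(self):
        # start every episode at the left-hand end and return the observation
        self.pos = 0
        return self.pos

    def step(self, action):
        # move, then return (next observation, reward, done)
        self.pos = max(0, self.pos + (1 if action == 1 else -1))
        if self.pos == self.size - 1:
            return 'terminal', 1, True   # reached the goal
        return self.pos, 0, False

    def render(self):
        pass  # the real environment redraws its tkinter canvas here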

3. SARSA Implementation: the Learning Brain (RL_brain)

import numpy as np
import pandas as pd


class RL(object):
    # Parent class shared by QLearningTable and SarsaTable. It has three parts:
    # 1. __init__ stores the hyper-parameters and the (initially empty) Q-table
    # 2. check_state_exist adds an all-zero row for any state not yet in the Q-table
    # 3. choose_action is the ε-greedy policy (exploit ~90% of the time, explore ~10%)
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list of action indices
        self.lr = learning_rate      # learning rate α
        self.gamma = reward_decay    # discount factor γ
        self.epsilon = e_greedy      # probability of exploiting the greedy action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append a new all-zero row for this state
            # (DataFrame.append was removed in pandas 2.0, so pd.concat is used instead)
            new_row = pd.Series(
                [0] * len(self.actions),
                index=self.q_table.columns,
                name=state,
                dtype=np.float64,
            )
            self.q_table = pd.concat([self.q_table, new_row.to_frame().T])

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # exploit: pick the best-known action, breaking ties at random
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # explore: pick a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        # the update rule differs between Q-learning and SARSA, so each subclass implements it
        pass

# off-policy
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # off-policy target: bootstrap from the best action in the next state
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            # next state is terminal, so there is nothing to bootstrap from
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

# on-policy
class SarsaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # pass the hyper-parameters straight through to the parent class
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        # unlike Q-learning, the next action a_ is also part of the update
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # on-policy target: use the Q-value of the action that will actually be taken next,
            # not the maximum over next-state actions as in Q-learning
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
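
As a quick sanity check of the update rule, here is a single SARSA update on a toy two-state, two-action problem. This is a standalone sketch that does not need the maze environment; the state names and numbers are made up purely for illustration.

from RL_brain import SarsaTable

RL = SarsaTable(actions=[0, 1], learning_rate=0.1, reward_decay=0.9)

# make sure both states have rows in the Q-table, and give the next state
# a non-zero value so there is something to bootstrap from
RL.check_state_exist('s0')
RL.check_state_exist('s1')
RL.q_table.loc['s1', 0] = 0.5

# one SARSA step: from 's0' the agent took action 1, got reward 1, landed in 's1',
# and its ε-greedy policy has already chosen action 0 for 's1'
RL.learn('s0', 1, 1, 's1', 0)

# q_target  = r + γ * Q('s1', 0) = 1 + 0.9 * 0.5 = 1.45
# q_predict = Q('s0', 1) = 0
# new Q('s0', 1) = 0 + 0.1 * (1.45 - 0) = 0.145
print(RL.q_table.loc['s0', 1])   # ≈ 0.145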
