The whole algorithm simply keeps updating the values in the Q table and then uses the updated values to decide which action to take in a given state. Q-learning is an off-policy algorithm, because the max over next actions means the Q table can be updated from experience other than what the agent is currently generating (it can learn from experience gathered long ago, or even from someone else's experience).
The Q function in Q-learning uses the following symbols (the full update rule is written out right after this list):
- s: the current state
- a: the action taken in the current state
- s': the new state produced by that action
- a': the action taken in the next state
- R: the reward received for this action
- α: the learning rate, e.g. 0.01
- γ: the discount factor, i.e. how much current reward is traded for long-term reward, e.g. 0.9
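Putting these symbols together, the update applied by the learn() method below is the standard Q-learning rule:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ R + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$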
The maze is a 5×6 matrix in which 0 marks a walkable cell and 1 marks an obstacle.
The q_table produced by the code looks like this:
up down left right
(0, 0) -0.550747 -0.533564 -0.644566 -0.410420
(0, 1) -0.811724 -0.344330 -0.362692 -0.354689
(0, 2) -0.510908 -0.571715 -0.354768 -0.354741
(1, 1) -0.297905 -0.247055 -0.478024 -0.537521
(0, 3) -0.599642 -0.512899 -0.354843 -0.354771
(0, 4) -0.546996 -0.470504 -0.354866 -0.354824
(0, 5) -0.370004 -0.361741 -0.354866 -0.397040
(2, 1) -0.259938 -0.109431 -0.464743 -0.526687
(3, 1) -0.176143 -0.403094 -0.368366 0.076880
(3, 2) -0.369096 -0.115697 -0.109689 0.296391
(4, 2) -0.069825 -0.237857 -0.136630 -0.087706
(4, 3) -0.018432 -0.078908 -0.068174 -0.066634
(4, 4) -0.117762 -0.079410 -0.066807 -0.066656
(3, 3) 0.533487 -0.066857 -0.045965 -0.223937
(2, 3) -0.164942 0.020808 -0.152385 0.767553
(4, 5) -0.069677 -0.069658 -0.066724 -0.098813
(2, 4) -0.049835 -0.063313 0.059299 0.993430
(2, 5) 0.000000 0.000000 0.000000 0.000000
q_table is a pandas DataFrame: the index is the state (an index into the maze matrix, stored as a string), and the columns are the actions.
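As a small illustration (a sketch, not part of the original code), here is how a single entry and the greedy action for one state can be read from such a table. The values are copied from the (0, 0) row above, and states are keyed by the string form of their (row, column) tuple, exactly as the code below stores them:

import pandas as pd

actions = ('up', 'down', 'left', 'right')
q_table = pd.DataFrame(
    [[-0.550747, -0.533564, -0.644566, -0.410420]],
    index=[str((0, 0))],
    columns=actions,
)
print(q_table.loc[str((0, 0)), 'right'])  # Q value of action 'right' in state (0, 0)
print(q_table.loc[str((0, 0))].idxmax())  # greedy action in state (0, 0): 'right'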
First, run train():
import numpy as np
import pandas as pd
import random
import pickle
from sklearn.utils import shuffle
# maze matrix
maze = np.array(
[[0, 0, 0, 0, 0, 0, ],
[1, 0, 1, 1, 1, 1, ],
[1, 0, 1, 0, 0, 0, ],
[1, 0, 0, 0, 1, 1, ],
[0, 1, 0, 0, 0, 0, ]]
)
print(pd.DataFrame(maze))
# start state
start_state = (0, 0)
# target (goal) state
target_state = (2, 5)
# file path where the trained q_table will be saved
q_learning_table_path = 'q_learning_table.pkl'
class QLearningTable:
def __init__(self, alpha=0.01, gamma=0.9):
        # self.alpha and self.gamma are the two parameters used in the Q-function update
self.alpha = alpha
self.gamma = gamma
        # reward (penalty) values: reward_0 for going out of bounds / hitting an obstacle, reward_1 for an ordinary step, reward_2 for reaching the target
self.reward_dict = {'reward_0': -1, 'reward_1': -0.1, 'reward_2': 1}
        # the available actions
self.actions = ('up', 'down', 'left', 'right')
self.q_table = pd.DataFrame(columns=self.actions)
def get_next_state_reward(self, current_state, action):
"""
:param current_state: 当前状态
:param action: 动作
:return: next_state下个状态,reward奖励值,done游戏是否结束
"""
done = False
if action == 'up':
next_state = (current_state[0] - 1, current_state[1])
elif action == 'down':
next_state = (current_state[0] + 1, current_state[1])
elif action == 'left':
next_state = (current_state[0], current_state[1] - 1)
else:
next_state = (current_state[0], current_state[1] + 1)
if next_state[0] < 0 or next_state[0] >= maze.shape[0] or next_state[1] < 0 or next_state[1] >= maze.shape[1] \
or maze[next_state[0], next_state[1]] == 1:
            # going out of bounds or hitting a 1 (obstacle): stay where we are
next_state = current_state
reward = self.reward_dict.get('reward_0')
            # Setting done=True here would mean the agent fell into a trap and the game ends;
            # with done=False the agent simply wastes a step in place, takes a penalty, and the game continues
            # done = True
        elif next_state == target_state:  # reached the target
reward = self.reward_dict.get('reward_2')
done = True
        else:  # maze[next_state[0], next_state[1]] == 0, an ordinary walkable cell
reward = self.reward_dict.get('reward_1')
return next_state, reward, done
    # update q_table from the returned reward and next_state
def learn(self, current_state, action, reward, next_state):
self.check_state_exist(next_state)
q_sa = self.q_table.loc[current_state, action]
max_next_q_sa = self.q_table.loc[next_state, :].max()
        # apply the Q-learning update formula
new_q_sa = q_sa + self.alpha * (reward + self.gamma * max_next_q_sa - q_sa)
        # write the new value back into q_table
self.q_table.loc[current_state, action] = new_q_sa
    # if the state is not yet in q_table, add it with all-zero action values
def check_state_exist(self, state):
if state not in self.q_table.index:
self.q_table.loc[state] = pd.Series(np.zeros(len(self.actions)), index=self.actions)
    # choose which action to execute
def choose_action(self, state, random_num=0.8):
series = pd.Series(self.q_table.loc[state])
        # With probability 1 - random_num (0.2 here) pick a random action, to try more possibilities.
        # Always making the best choice means you may miss paths that were never explored,
        # so a random term is added rather than always taking the action that currently looks best.
if random.random() > random_num:
action = random.choice(self.actions)
else:
            # The maximum value in a pd.Series may occur more than once, and idxmax() only returns the first one,
            # so sklearn's shuffle is used to randomize the order and break ties randomly.
            # Picking a maximizing action helps the q_table converge quickly.
            ss = shuffle(series)
            action = ss.idxmax()  # idxmax() returns the action label of the maximum (argmax() was its old name, but now returns a position)
return action
# training
def train():
q_learning_table = QLearningTable()
    # number of training iterations (episodes)
iterate_num = 500
for _ in range(iterate_num):
        # each iteration starts from start_state
current_state = start_state
while True:
            # first make sure current_state is already in q_table; note that states are stored in q_table as strings
q_learning_table.check_state_exist(str(current_state))
            # pick the action to execute in the current state
action = q_learning_table.choose_action(str(current_state))
            # from current_state and action, get the next state next_state, the reward, and whether the game is over (done)
next_state, reward, done = q_learning_table.get_next_state_reward(current_state, action)
            # learn from this step and update q_table
q_learning_table.learn(str(current_state), action, reward, str(next_state))
            # if the game is over, break out of the while loop and start the next iteration
if done:
break
            # move current_state to the next state
current_state = next_state
print('game over')
    # save the q_learning_table object to the file q_learning_table_path
with open(q_learning_table_path, 'wb') as pkl_file:
pickle.dump(q_learning_table, pkl_file)
After train() finishes, it produces a file q_learning_table.pkl that stores the trained QLearningTable object.
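As a quick sanity check (a sketch, not part of the original code), the saved table can be loaded and the greedy action for every visited state printed. This has to run in the same script/session where QLearningTable is defined, since pickle needs the class to rebuild the object:

import pickle

with open('q_learning_table.pkl', 'rb') as f:
    trained = pickle.load(f)
# highest-value action for each state the agent has visited
print(trained.q_table.astype(float).idxmax(axis=1))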
Then run the following predict() code to test the model.
# prediction
def predict():
    # load the trained q_table
with open(q_learning_table_path, 'rb') as pkl_file:
q_learning_table = pickle.load(pkl_file)
print('start_state:{}'.format(start_state))
current_state = start_state
step = 0
while True:
step = step + 1
action = q_learning_table.choose_action(str(current_state), random_num=1)
        # during prediction the reward is no longer needed, so it is discarded with _
next_state, _, done = q_learning_table.get_next_state_reward(current_state, action)
        # print the action and the next state
print('step:{step}, action: {action}, state: {state}'.format(step=step, action=action, state=next_state))
        # if done, or more than 100 steps have been taken, the game ends and we exit
if done or step > 100:
if next_state == target_state:
print('success')
else:
print('fail')
break
        # move to the next state
else:
current_state = next_state
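One way to tie the two steps together is a minimal entry point like the following (an assumption about how the script is run, not part of the original); the output below comes from the predict() call:

if __name__ == '__main__':
    train()
    predict()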
Output:
start_state:(0, 0)
step:1, action: right, state: (0, 1)
step:2, action: down, state: (1, 1)
step:3, action: down, state: (2, 1)
step:4, action: down, state: (3, 1)
step:5, action: right, state: (3, 2)
step:6, action: right, state: (3, 3)
step:7, action: up, state: (2, 3)
step:8, action: right, state: (2, 4)
step:9, action: right, state: (2, 5)
success
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/2_Q_Learning_maze