在文章正式开始前,请不要被强化学习的tag给吓到了,这也是我之前所遇到的一个困扰。觉得这个东西看上去很高级,需要一个完整的时间段,做详细的学习。相反,强化学习的很多算法是很符合直观思维的。 因此,强化学习的算法思想反而会是相当直观的。
另外,需要强调的是,这个算法在很多地方都有很详细的阐述了。这篇文章的工作,很多也是基于前辈的工作而继续推进的。这里也引用方便后来者进一步学习。这里再次感谢前辈的工作,确实对我有较大的帮助。
这个算法异常的简单。
虽然可能存在非常多的改进点 or 存在大量的应用场景下的不兼容,但又有点基石的感觉,还是值得学习一下,感受其中的思想的。
简单来说,就是维护一张Q表
。
Q表
,存储的表记录的是,在状态S
下,每个行为A
的Q值。一般的更新的公式 是
Q [ S , A ] = ( 1 − α ) ∗ Q [ S , A ] + α ∗ ( R + γ ∗ m a x Q [ S n e x t , : ] ) Q[S, A] = (1-\alpha)*Q[S, A] + \alpha*(R + \gamma * max{Q[S_{next}, :]}) Q[S,A]=(1−α)∗Q[S,A]+α∗(R+γ∗maxQ[Snext,:])
对于下一步是终点的更新公式 是
Q [ S , A ] = ( 1 − α ) ∗ Q [ S , A ] + α ∗ R Q[S, A] = (1-\alpha)*Q[S, A] + \alpha*R Q[S,A]=(1−α)∗Q[S,A]+α∗R
有一说一,还挺像动态规划的,这么想想,是不是觉得这个算法,初、高中生其实也都可以学会?
Q表示是一个矩阵。
到这里,对算法应该是有基本的概念。
接下来的问题是,如何对Q表更新
呢?
也就是,所谓的Q-Learning
想法非常直观。
如果在最后一步,选择对了,那么是不是就是给上一个状态的所执行的ACTION有个比较好的奖励。比如,奖励R=1
。
那有个问题,按照上面的更新的话,倒数第二步,或者那些让整个比赛没办法直接结束的状态,就没办法得到了奖励。
为了解决这个问题。Q中,加入了一个预测的概念。
即,对于那些没有办法直接获得奖励的状态,他的奖励更新(或者是叫Q表更新),会基于该状态下,执行该操作之后的 新的状态的所有操作中的最大Q值来更新 。
当然也许本次操作本身也是有对应奖励,这就另外说了。
具体算法是:
q_predict=1
q_predict=R+GAMMA*q_table[S_New].max()
,其中R
表示该操作本身的奖励,算是个局部信息,GAMMA
表示这种预测的信息的传递损失。很自然的设计,算是为了避免陷入局部最优解。ALPHA
的概念,这个学过深度学习or机器学习的都会觉得很自然的了。q_table[S, A] += ALPHA * (q_predict - q_table[S, A])
。同样也是为了避免陷入局部最优的问题。至此,算法讲完了。是不是很简单
treasure on right
的弱智游戏来实现Q-Learning这个算法。
N_STATES
表示状态数量。其实就是位置数量。EPSILON
就是 ϵ \epsilon ϵ-greedy 的 ϵ \epsilon ϵMAX_EPISODES
表示玩多少轮游戏来训练。FRESH_TIME
是用来输出的参数,控制多久刷新一次页面之类。(用来好看的)TerminalFlag
用来记录游戏结束的标志符,方便统一,就放在外面。import time
import numpy as np
import pandas as pd
N_STATES = 6
ACTIONS = ["left", "right"]
EPSILON = 0.9
ALPHA = 0.1
GAMMA = 0.9
MAX_EPISODES = 15
FRESH_TIME = 0.3
TerminalFlag = "terminal"
def build_q_table(n_states, actions):
return pd.DataFrame(
np.zeros((n_states, len(actions))),
columns=actions
)
def choose_action(state, q_table):
state_table = q_table.loc[state, :]
if (np.random.uniform() > EPSILON) or ((state_table == 0).all()):
action_name = np.random.choice(ACTIONS)
else:
action_name = state_table.idxmax()
return action_name
def get_env_feedback(S, A):
if A == "right":
if S == N_STATES - 2:
S_, R = TerminalFlag, 1
else:
S_, R = S + 1, 0
else:
S_, R = max(0, S - 1), 0
return S_, R
def update_env(S, episode, step_counter):
env_list = ["-"] * (N_STATES - 1) + ["T"]
if S == TerminalFlag:
interaction = 'Episode %s: total_steps = %s' % (episode + 1, step_counter)
print(interaction)
time.sleep(2)
else:
env_list[S] = '0'
interaction = ''.join(env_list)
print(interaction)
time.sleep(FRESH_TIME)
def rl():
q_table = build_q_table(N_STATES, ACTIONS)
for episode in range(MAX_EPISODES):
step_counter = 0
S = 0
is_terminated = False
update_env(S, episode, step_counter)
while not is_terminated:
A = choose_action(S, q_table)
S_, R = get_env_feedback(S, A)
q_predict = q_table.loc[S, A]
if S_ != TerminalFlag:
q_target = R + GAMMA * q_table.loc[S_, :].max()
else:
q_target = R
is_terminated = True
q_table.loc[S, A] += ALPHA * (q_target - q_predict)
S = S_
update_env(S, episode, step_counter + 1)
step_counter += 1
return q_table
if __name__ == '__main__':
q_table = rl()
print(q_table)