原文链接
https://mp.weixin.qq.com/s/nm...
效果展示
参见:https://zhuanlan.zhihu.com/p/...
原理简介
原理其实在这篇文章里讲过:
不过今天我们将尝试只用Q-Learning算法,而不是DQN来玩FlappyBird这款经典的小游戏。有些懒癌患者可能不想点击上面的超链接去看那么长的文章,所以这里我们先简单介绍一下Q-Learning算法,然后再说下如何用这个算法玩FlappyBird这个游戏,以及我们的代码实现。
1.Q-Learning算法
好像直接讲算法有点突兀,那就先举一个网上比较经典的例子吧:
假设需要创建一个智能体agent,他需要去捡右上角的蓝宝石,agent每次可以移动一个方块的距离,若agent踩到陷阱就会死掉,那么如何设计策略才能让agent学会去捡右上角的蓝宝石同时不会在路途中因为踩到陷阱而死掉呢?
一个比较直观的方案就是创建类似下图这样的表格:
即对于当前的状态S(也就是agent现在在哪个方块中),我们都可以为agent计算出每种动作A(即上下左右移动)最大的未来期望奖励R,从而可以知道每个状态应当采取的最佳动作。将上面的表格弄成这样,就是我们熟悉的Q-table啦(Q代表动作的质量):
那么我们如何获得Q-table中的值呢?这时候就需要Q-Learning算法闪亮登场啦。该算法的基本流程如下:
Q-Learning的思想基于价值迭代,直观地理解就是每次利用新得到的reward和原本的Q值来更新现在的Q值。其数学形式表示为:
其中Q是当前的table,Q*是更新后的table,r是在状态s时采取动作a后获得的奖励,α可以当作是学习率,γ一般称为折扣因子,用于定义未来奖励的重要性。max{Q(s', a')}用于计算进行动作a后进入新的状态s'时可以获得的最大奖励。至于这个公式咋来的,还是请参见:
感觉再打下去还不如你们自己跳转去看了。
2.FlappyBird游戏实现
学习了Q-learning算法之后,我们需要先实现一下我们的FlappyBird小游戏,然后再考虑怎么把算法用在这个游戏上?
显然,我们之前已经写过一个这样的游戏了:
当然,为了方便实现后面的算法,我们对游戏做了一些微小的改动,即我们把小鸟上下移动的速度都做了取整化处理(也就是速度每次加1或者减1了,并且假设每帧的时间为单位1,该帧内小鸟的速度仍然假设为保持不变)。
3.如何用Q-Learning玩FlappyBird?
其实很简单,只需要明确状态state,动作action和奖赏reward,然后往算法里套就OK啦~
对于状态,我们假设小鸟当前的状态定义为:
s = (delta_x, delta_y, speed)
--delta_x:小鸟和即将通过那组管道的下半部分,水平方向上的距离
--delta_y:小鸟和即将通过那组管道的下半部分,竖直方向上的距离
--speed:小鸟当前的速度
一个丑陋的示意图:
当然delta_x和delta_y也可以用其他方式定义,这个无所谓的。
动作的话无非是这样:
a = 1 向上飞一下
a = 0 啥都不做
奖赏的话,可以这样:
reward = 1 平安无事
reward = -1000000 小鸟死掉了
reward = 5 小鸟成功通过了一组管道
反正只要合理,应该大差不差。
接下来的事情就是套算法了,具体而言,其核心代码实现如下:
'''q learning agent'''
class QLearningAgent():
def __init__(self, mode, **kwargs):
self.mode = mode
# learning rate
self.learning_rate = 0.7
# discount factor(also named discount rate)
self.discount_factor = 0.95
# store the necessary history data, the format is [previous_state, previous_action, state, reward]
self.history_storage = []
# store the q values, the last dimension is [value_for_do_nothing, value_for_flappy]
self.qvalues_storage = np.zeros((130, 130, 20, 2))
# store the score for each episode
self.scores_storage = []
# previous state
self.previous_state = []
# 0 means do nothing, 1 means flappy
self.previous_action = 0
# number of episode
self.num_episode = 0
# the max score so far
self.max_score = 0
'''make a decision'''
def act(self, delta_x, delta_y, bird_speed):
if not self.previous_state:
self.previous_state = [delta_x, delta_y, bird_speed]
return self.previous_action
if self.mode == 'train':
state = [delta_x, delta_y, bird_speed]
self.history_storage.append([self.previous_state, self.previous_action, state, 0])
self.previous_state = state
# make a decision according to the qvalues
if self.qvalues_storage[delta_x, delta_y, bird_speed][0] >= self.qvalues_storage[delta_x, delta_y, bird_speed][1]:
self.previous_action = 0
else:
self.previous_action = 1
return self.previous_action
'''set reward'''
def setReward(self, reward):
if self.history_storage:
self.history_storage[-1][3] = reward
'''update the qvalues_storage after an episode'''
def update(self, score, is_logging=True):
self.num_episode += 1
self.max_score = max(self.max_score, score)
self.scores_storage.append(score)
if is_logging:
print('Episode: %s, Score: %s, Max Score: %s' % (self.num_episode, score, self.max_score))
if self.mode == 'train':
history = list(reversed(self.history_storage))
# penalize last num_penalization states before crash
num_penalization = 2
for item in history:
previous_state, previous_action, state, reward = item
if num_penalization > 0:
num_penalization -= 1
reward = -1000000
x_0, y_0, z_0 = previous_state
x_1, y_1, z_1 = state
self.qvalues_storage[x_0, y_0, z_0, previous_action] = (1 - self.learning_rate) * self.qvalues_storage[x_0, y_0, z_0, previous_action] +\
self.learning_rate * (reward + self.discount_factor * max(self.qvalues_storage[x_1, y_1, z_1]))
self.history_storage = []
'''save the model'''
def saveModel(self, modelpath):
data = {
'num_episode': self.num_episode,
'max_score': self.max_score,
'scores_storage': self.scores_storage,
'qvalues_storage': self.qvalues_storage
}
with open(modelpath, 'wb') as f:
pickle.dump(data, f)
print('[INFO]: save checkpoints in %s...' % modelpath)
'''load the model'''
def loadModel(self, modelpath):
print('[INFO]: load checkpoints from %s...' % modelpath)
with open(modelpath, 'rb') as f:
data = pickle.load(f)
self.num_episode = data.get('num_episode')
self.qvalues_storage = data.get('qvalues_storage')
OK,大功告成,完整源代码详见相关文件~
参考文献:
[1].https://blog.csdn.net/qq_30615903/article/details/80739243
[2].https://www.zhihu.com/question/26408259
[3].https://zhuanlan.zhihu.com/p/35724704