【飞桨开发者说】韩磊,台湾清华大学资讯工程学系硕士,现创业公司算法工程师,百度强化学习7日营学员
强化学习7日打卡营AI Studio课程主页:
https://aistudio.baidu.com/aistudio/course/introduce/1335
B站课程链接:
https://www.bilibili.com/video/BV1yv411i7xd
《Flappy Bird》相信大家都玩过或者看过,这款游戏在2014年火遍全球。其操作非常简单,只需要点击屏幕,让主角小鸟顺利地穿过水管之间的缝隙而不碰触任何障碍物。小鸟穿过的水管越多,得到的分数也就越高。
今天我们也来玩一玩这个游戏,不过我们使用强化学习的算法来让主角小鸟自己学会穿过水管躲避障碍,进而魔改环境,制作特殊的三人环境,让游戏进阶为《Flappy Paddle》。别愣着,看下去,看完我们一起划船。
学习这篇文章,你可以做出下面视频中的效果。这里在红黑两支队伍被淘汰之后,结束了录制。因为蓝色的算法可以跑很久,这里只是作为展示,所以没有继续录下去。
飞桨有众多方便好用的开发工具套件,其中PARL就是在强化学习方向的一个高性能、灵活的框架,目前已经在Github上开源。PARL支持大规模并行计算,同样提供了算法的可复现性保证。PARL的框架逻辑清晰,容易上手,从Model到Algorithm再到Agent,逐步构建智能体。同时PARL也提供了一些经典的强化学习算法代码示例,如PG、DDPG、A2C等,方便开发者的调研和验证。不仅如此,PARL还提供了比较完善的算法基类,这使得PARL的扩展性也很好,开发更为轻松快捷。
在我们这个项目中,使用的就是PARL这个开发工具套件。PARL的仓库,针对很多经典的强化学习方法也提供了对应的例子。本项目使用的DQN方法,也是在PARL的实现上的变化。
环境解析
我们从《Flappy Bird》这个小游戏开始。我们使用PyGame-Learning-Environment这个环境,你可以在Github上轻松的找到这个仓库。下面来分析一下上述的几个元素。
对于观测值,我们可以通过getGameState函数得到一个观测字典,其中包含了8个字段,包括了玩家(游戏里是一个小鸟)的坐标信息、速度信息、玩家距离下一水管和再下一根水管的位置信息。当然你也可以直接使用getScreenRGB函数得到画面,并以它为观测值。这里为了简单操作,我们以观测字典为例。同时,我们也能发现这个观测值是连续的。
对于动作,我们可以通过getActionSet函数得到环境所支持的动作。在《Flappy Bird》这个游戏里,只有两个动作:1、点击屏幕让小鸟展翅高飞,2、什么都不做让小鸟自由滑翔。由此我们可以知道环境接受的动作是离散有限的。
在奖惩方面,环境是这样定义:reward = 当前帧的总分 - 前一帧的总分。总分的变化有两种情况:1、玩家通过管子,得一分。2、玩家撞天花板、地板,管子则游戏失败,扣五分。
算法选择
那在这个任务中应该选择DQN还是PG呢?笔者两者都尝试了,而DQN可以轻松的训练出不错的效果,而PG却不能。我觉得可以从以下角度分析。DQN针对每一个观测的每一个动作做出评价,也就是说每一个动作都会有其价值。而在训练PG的时候,需要先跑完一个episode,然后将奖惩回传到这局游戏中的每一个动作上。
对于这个过程中的每一个动作,在这个任务中这种回传不是一个好的反馈方法。举一个例子:现在玩家在第五根管子前,真正影响面对第五根管子时动作的,是经过第四根管子之后的动作(以及第五根第六根的管子的位置,这属于观测值)。而更早之前的经过第一根、第二根、第三根管子的动作是不影响经过第五根管子的,那么这个奖惩回传的方法在这个任务中就很有问题。
搭建DQN及训练
鉴于PARL清晰的框架结构和完整的基类,我们构建Agent也更加容易。按照先model,再Algorithm,最后定义Agent的步骤来。这个项目的代码都是基于PARL中的DQN的例子的。
这里附上样例链接:
https://github.com/PaddlePaddle/PARL/tree/develop/examples/DQN
首先我们简单地设计一个包含三个隐层的网络,在PARL中的model需要继承parl.Model这样的基类。
class Model(parl.Model):
def __init__(self, act_dim):
hid0_size = 64
hid1_size = 32
hid2_size = 16
self.fc0 = layers.fc(size=hid0_size, act='relu', name="fc0")
self.fc1 = layers.fc(size=hid1_size, act='relu', name="fc1")
self.fc2 = layers.fc(size=hid2_size, act='relu', name="fc2")
self.fc3 = layers.fc(size=act_dim, act=None, name="fc3")
def value(self, obs):
h0 = self.fc0(obs)
h1 = self.fc1(h0)
h2 = self.fc2(h1)
Q = self.fc3(h2)
return Q
有一点动态图构建模型的感觉是不是?所以在模型方面你可以有更多的想法和设计,例如我还设计了以下这种模型:
class catModel(parl.Model):
def __init__(self, act_dim):
hid0_size = 64
hid1_size = 32
hid2_size = 16
self.fc0 = layers.fc(size=hid0_size, act='relu', name="catfc0")
self.fc1 = layers.fc(size=hid1_size, act='relu', name="catfc1")
self.fc2 = layers.fc(size=hid2_size, act='relu', name="catfc2")
self.fc3 = layers.fc(size=act_dim, act=None, name="catfc3")
def value(self, last_obs, obs):
oobs = fluid.layers.concat(input=[last_obs, obs], axis=-1, name='concat')
h0 = self.fc0(oobs)
h1 = self.fc1(h0)
h2 = self.fc2(h1)
Q = self.fc3(h2)
return Q
可以看出来,这里是将last_obs和obs直接concat到一起作为全连接层的输入。这里的last_obs,是上一帧的观测值,obs是当前帧的观测值。也许这种模型的效果并不会更好,但仍是一个值得尝试的想法。
接下来是algorithm,PARL中已有DQN的实现,我们直接使用PARL中提供的DQN类。像样例中一样,我们直接import算法就可以。
from parl.algorithms import DQN
当然这种写法并不适合于我刚才的第二种做法,因为第二种方法的value函数,接受的是last_obs, obs两个参数。所以这里你可以继承基类DQN或是直接重构一个。放心,有了PARL提供的样例,这个过程会非常的简单。基本上重写predict和learn两个成员函数就好。这两个函数也是之后“暴露”给Agent使用的。
predict函数用来拿到模型的输出,也就是所谓的Q值。而learn函数则是根据模型的输出和Agent拿到的经验数据去构建模型的cost,并使用优化器来最小化它,从而达到训练模型的目的。
最后是Agent,如果你使用的是我刚才第一种model,那么你可以直接使用样例中Agent的定义,但如果你使用了第二种,那当然也要修改对应的build_program、sample、predict、learn几个成员函数以能够成功的构建模型并调用Algorithm定义的函数。
接下来就可以训练我们的模型了,大概几百个episode之后,我们的Agent就能够拿到正的分数(其实这个时候,分值已经超过5分了)
修改贴图资源,制作三人环境
但是一个队伍划船总有一些孤单,能不能让多个Agent在同一环境下一起“比赛”呢?与其说把环境写“死”,每次读取来评判不同的Agent,不如就让他们在同一环境下一起出发,这种方式更加直观。
这个地方需要修改的是PyGame-Learning-Environment/ple/games/flappybird下的__init__.py文件,这个文件中定义了整个游戏的逻辑。
这里就不更具体的说了,因为涉及的更多的是pygame的知识。__init__.py中需要修改的地方大概有:
初始化定义三个player。
为每个player添加score和live属性及每个player对应的得分和死亡处理,以及游戏的score和结束条件。
设计新的actionset, 以能接受三个输入(实际上是一个输入包含三个Agent的三个action)。
设计新的observation。在此之前只返回一个观测值,但现在要针对每个player返回其对应的观测值。
图像绘制。在原来的基础上多绘制两个player。
在仓库中提供了修改好了__init__.py以及图像资源,提供了一些设计环境的想法。
成果
总结
那么在哪里能学到以上酷炫又有趣的知识呢?AI Studio上现有一门课程:《强化学习7日打卡营-世界冠军带你从零实践》,通过学习,你可以对强化学习有一个初步的了解,学到Q-learning、Sarsa、DQN、Policy Gradient等。几个清晰有趣的案例和作业,在充满趣味的同时,加强对算法和代码实现的理解。当然,也可以和我一样扩展思路,魔改环境,开发更多有趣又有技术的项目。
视频预览 :
https://www.bilibili.com/video/BV1KV411674k
AI Studio项目链接 :
https://aistudio.baidu.com/aistudio/projectdetail/609617
百度AI Studio课程平台
扫码加入课程,即可观看《世界冠军带你从零实践强化学习》的完整课节内容,动手实践案例和代码,遇到作业问题还可以到讨论区寻找答案。
最后,别忘了加入微信学习群,风里雨里我们在群里等你~
如在使用过程中有问题,可加入飞桨官方QQ群进行交流:1108045677。
如果您想详细了解更多飞桨的相关内容,请参阅以下文档。
官网地址:
https://www.paddlepaddle.org.cn
飞桨开源框架项目地址:
GitHub:
https://github.com/PaddlePaddle/Paddle
Gitee:
https://gitee.com/paddlepaddle/Paddle
飞桨生成对抗网络项目地址:
GitHub:
https://github.com/PaddlePaddle/PARL
Gitee:
https://gitee.com/paddlepaddle/PARL
END