Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission.
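In 2013 DeepMind published the paper Playing Atari with Deep Reinforcement Learning at NIPS, proposing the DQN (Deep Q Network) algorithm, which learns to play Atari games end to end: the only input is pixels, so the agent plays by watching the screen. On the strength of this application, DeepMind was acquired by Google for a reported US$600 million. Since DQN was open-sourced, many different DQN implementations have appeared on GitHub, but most of them reproduce the Atari games, involve a lot of code, and are hard to follow.
Flappy Bird is an extremely simple yet difficult game that was once hugely popular. Quite some time ago, someone already used the Q-Learning algorithm to play Flappy Bird: http://sarvagyavaish.github.io/FlappyBirdRL/
However, that implementation works by reading the bird's exact position.
Whether DQN can learn to play Flappy Bird from the screen alone is an interesting challenge. (A friend and I considered this idea at the end of last year, but at the time we did not know how to capture the game screen, so we could only learn from the exact position. That did work, though.)
Recently, someone released code on GitHub that uses DQN to play Flappy Bird: https://github.com/yenchenlin1994/DeepLearningFlappyBird [1]
That repo realizes the idea by combining earlier repos, and it explains the whole implementation in reasonable detail. However, its DQN code is largely taken from someone else's repo, so it is rather messy and hard to understand.
For this reason, I rewrote a version: https://github.com/songrotek/DRL-FlappyBird
It rewrites the DQN code, essentially wrapping it in a class, which makes the code more general and easy to port to other applications.
The purpose of this article, of course, is to use this Flappy Bird DQN code to analyze the DQN algorithm and its use in detail.
This is the pseudocode of the NIPS 2013 version: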
```
Initialize replay memory D to size N
Initialize action-value function Q with random weights
for episode = 1, M do
    Initialize state s_1
    for t = 1, T do
        With probability ϵ select random action a_t
        otherwise select a_t = argmax_a Q(s_t, a; θ_i)
        Execute action a_t in emulator and observe r_t and s_(t+1)
        Store transition (s_t, a_t, r_t, s_(t+1)) in D
        Sample a minibatch of transitions (s_j, a_j, r_j, s_(j+1)) from D
        Set y_j :=
            r_j                                      for terminal s_(j+1)
            r_j + γ * max_(a') Q(s_(j+1), a'; θ_i)   for non-terminal s_(j+1)
        Perform a gradient step on (y_j - Q(s_j, a_j; θ_i))^2 with respect to θ
    end for
end for
```
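To make the two central steps of the pseudocode concrete, here is a minimal NumPy sketch of ϵ-greedy action selection and of the target y_j. The function names `epsilon_greedy` and `td_targets` and the toy Q values are illustrative assumptions; they are not taken from the paper or from either repo.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

def td_targets(rewards, next_q_values, terminals, gamma):
    """y_j = r_j for terminal s_(j+1), else r_j + gamma * max_a' Q(s_(j+1), a')."""
    rewards = np.asarray(rewards, dtype=np.float64)
    next_max = np.max(np.asarray(next_q_values, dtype=np.float64), axis=1)
    not_terminal = 1.0 - np.asarray(terminals, dtype=np.float64)
    return rewards + gamma * next_max * not_terminal

rng = np.random.default_rng(0)
# With epsilon = 0 the choice is always greedy.
greedy_action = epsilon_greedy(np.array([0.2, 0.7]), epsilon=0.0, rng=rng)
# Two toy transitions: the first is non-terminal, the second is terminal.
targets = td_targets(
    rewards=[0.1, -1.0],
    next_q_values=[[1.0, 2.0], [3.0, 4.0]],
    terminals=[False, True],
    gamma=0.99,
)
```

Multiplying by `not_terminal` is just a compact way of writing the two cases of y_j in one vectorized expression.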
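For the basic analysis, see Paper Reading 1 - Playing Atari with Deep Reinforcement Learning.
For the fundamentals, see Deep Reinforcement Learning Basics (DQN).
This article mainly looks at how to write the Flappy Bird DQN code from an implementation point of view.
First, the Flappy Bird game itself is already written and ready to use. It provides a very simple interface: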
```python
nextObservation,reward,terminal = game.frame_step(action)
```
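That is, you pass in an action and get back the screenshot after the action has been executed, the resulting reward, and whether the game has ended.
Now think of DQN as a brain, which we represent with the BrainDQN class. This class only needs to receive perception, namely the observation (screenshot), the reward, and the terminal flag described above, and then output an action.
That is what a clean encapsulation should look like: how DQN stores its data and how it trains are of no concern to the outside.
As a result, our FlappyBirdDQN code is only this short: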
```python
# -------------------------
# Project: Deep Q-Learning on Flappy Bird
# Author: Flood Sung
# Date: 2016.3.21
# -------------------------

import cv2
import sys
sys.path.append("game/")
import wrapped_flappy_bird as game
from BrainDQN import BrainDQN
import numpy as np

# preprocess raw image to 80*80 gray image
def preprocess(observation):
    observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
    ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY)
    return np.reshape(observation,(80,80,1))

def playFlappyBird():
    # Step 1: init BrainDQN
    brain = BrainDQN()
    # Step 2: init Flappy Bird Game
    flappyBird = game.GameState()
    # Step 3: play game
    # Step 3.1: obtain init state
    action0 = np.array([1,0])  # do nothing
    observation0, reward0, terminal = flappyBird.frame_step(action0)
    observation0 = cv2.cvtColor(cv2.resize(observation0, (80, 80)), cv2.COLOR_BGR2GRAY)
    ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY)
    brain.setInitState(observation0)

    # Step 3.2: run the game
    while True:
        action = brain.getAction()
        nextObservation,reward,terminal = flappyBird.frame_step(action)
        nextObservation = preprocess(nextObservation)
        brain.setPerception(nextObservation,action,reward,terminal)

def main():
    playFlappyBird()

if __name__ == '__main__':
    main()
```
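The core is the while loop. Because each image must be converted to an 80x80 grayscale image (which is then thresholded to black and white), a preprocess function was added.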
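Note that preprocess returns a single 80x80x1 frame, while the network's input placeholder is 80x80x4, so four frames have to be stacked somewhere inside the brain. The sketch below shows one plausible way to do that with NumPy. The helper names `init_state` and `next_state` are hypothetical, and the stacking scheme is an assumption about what setInitState and setPerception do, not code quoted from the repo.

```python
import numpy as np

def init_state(observation):
    """Repeat the first 80x80 frame four times -> shape (80, 80, 4)."""
    return np.stack((observation, observation, observation, observation), axis=2)

def next_state(current_state, next_observation):
    """Drop the oldest frame and append the newest (80, 80, 1) frame."""
    return np.append(current_state[:, :, 1:], next_observation, axis=2)

frame0 = np.zeros((80, 80), dtype=np.uint8)          # first frame, all black
frame1 = np.full((80, 80, 1), 255, dtype=np.uint8)   # next frame, all white
state = init_state(frame0)
state = next_state(state, frame1)
```

Stacking consecutive frames is what lets the network see motion, since a single screenshot does not reveal whether the bird is rising or falling.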
Clearly, as long as a game engine is available, switching to another game takes the same code, which is very convenient.
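To illustrate that decoupling, here is a small runnable sketch in which both sides are replaced by stand-ins. `DummyGame`, `RandomBrain`, and `play` are hypothetical names invented for this example, and the reward values are made up; only the method names frame_step, setInitState, getAction, and setPerception come from the code above.

```python
import numpy as np

class DummyGame:
    """Hypothetical stand-in for game.GameState with the same frame_step call."""
    def __init__(self, episode_length=5):
        self.episode_length = episode_length
        self.t = 0

    def frame_step(self, action):
        self.t += 1
        observation = np.zeros((80, 80, 1), dtype=np.uint8)  # fake screen
        terminal = self.t % self.episode_length == 0
        reward = -1.0 if terminal else 0.1  # made-up reward values
        return observation, reward, terminal

class RandomBrain:
    """Hypothetical stand-in for BrainDQN with the same three public methods."""
    def __init__(self, num_actions=2, seed=0):
        self.num_actions = num_actions
        self.rng = np.random.default_rng(seed)
        self.transitions = 0

    def setInitState(self, observation):
        self.currentObservation = observation

    def getAction(self):
        action = np.zeros(self.num_actions)
        action[self.rng.integers(self.num_actions)] = 1  # one-hot action
        return action

    def setPerception(self, nextObservation, action, reward, terminal):
        self.transitions += 1  # a real brain would store and train here
        self.currentObservation = nextObservation

def play(env, brain, steps):
    """Same loop as playFlappyBird, but bounded so that it terminates."""
    observation0, _, _ = env.frame_step(np.array([1, 0]))  # do nothing
    brain.setInitState(observation0)
    for _ in range(steps):
        action = brain.getAction()
        nextObservation, reward, terminal = env.frame_step(action)
        brain.setPerception(nextObservation, action, reward, terminal)
    return brain

brain = play(DummyGame(), RandomBrain(), steps=10)
```

Neither class knows anything about the other beyond these four calls, which is why swapping the game or the agent does not change the loop.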
Next we write BrainDQN.py, our game brain:
```python
from collections import deque

class BrainDQN:
    def __init__(self):
        # init replay memory
        self.replayMemory = deque()
        # init Q network
        self.createQNetwork()

    # The method bodies are placeholders here; each part is written below.
    def createQNetwork(self):
        pass

    def trainQNetwork(self):
        pass

    def setPerception(self,nextObservation,action,reward,terminal):
        pass

    def getAction(self):
        pass

    def setInitState(self,observation):
        pass
```
The basic architecture needs only the functions above; anything else is superfluous. Next we write the code for each part.
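The skeleton keeps the replay memory in a deque, which corresponds to the memory D of size N in the pseudocode. As a rough sketch of how such a memory can be bounded and sampled, the example below uses `deque(maxlen=...)`. The constants `REPLAY_MEMORY` and `BATCH_SIZE`, their tiny values, and the string placeholders for states are assumptions chosen for illustration, and `maxlen` is only one of several ways to enforce the size limit.

```python
import random
from collections import deque

REPLAY_MEMORY = 4   # hypothetical capacity N, kept tiny for illustration
BATCH_SIZE = 2      # hypothetical minibatch size

# With maxlen set, appending to a full deque discards the oldest transition.
replay_memory = deque(maxlen=REPLAY_MEMORY)
for t in range(6):
    # (state, action, reward, next state, terminal) with placeholder states
    transition = (f"s{t}", t % 2, 0.1, f"s{t + 1}", False)
    replay_memory.append(transition)

random.seed(0)
minibatch = random.sample(replay_memory, BATCH_SIZE)
```

Sampling uniformly from past transitions, rather than training on consecutive frames, is what breaks the strong correlation between successive samples.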
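This is the createQNetwork part, which uses the structure shown in the figure below (taken from [1]):
The whole pipeline is not explained here. The main point is to design the convolutional and fully connected layers for the specific input type and output.
The code is as follows: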
```python
    def createQNetwork(self):
        # network weights
        W_conv1 = self.weight_variable([8,8,4,32])
        b_conv1 = self.bias_variable([32])

        W_conv2 = self.weight_variable([4,4,32,64])
        b_conv2 = self.bias_variable([64])

        W_conv3 = self.weight_variable([3,3,64,64])
        b_conv3 = self.bias_variable([64])

        W_fc1 = self.weight_variable([1600,512])
        b_fc1 = self.bias_variable([512])

        W_fc2 = self.weight_variable([512,self.ACTION])
        b_fc2 = self.bias_variable([self.ACTION])

        # input layer
        self.stateInput = tf.placeholder("float",[None,80,80,4])

        # hidden layers
        h_conv1 = tf.nn.relu(self.conv2d(self.stateInput,W_conv1,4) + b_conv1)
        h_pool1 = self.max_pool_2x2(h_conv1)

        h_conv2 = tf.nn.relu(self.conv2d(h_pool1,W_conv2,2) + b_conv2)

        h_conv3 = tf.nn.relu(self.conv2d(h_conv2,W_conv3,1) + b_conv3)

        h_conv3_flat = tf.reshape(h_conv3,[-1,1600])
        h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat,W_fc1) + b_fc1)

        # Q Value layer
        self.QValue = tf.matmul(h_fc1,W_fc2) + b_fc2

        self.actionInput = tf.placeholder("float",[None,self.ACTION])
        self.yInput = tf.placeholder("float", [None])
        # tf.mul is the pre-1.0 TensorFlow name; later versions call it tf.multiply
        Q_action = tf.reduce_sum(tf.mul(self.QValue, self.actionInput), reduction_indices = 1)
        self.cost = tf.reduce_mean(tf.square(self.yInput - Q_action))
        self.trainStep = tf.train.AdamOptimizer(1e-6).minimize(self.cost)
```
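Remember that the network outputs Q-values. The key step is computing the cost, and the crux of that is Q_action: the Q-value of the given state for the action that was actually taken. Because actionInput is a one-hot vector, tf.mul(self.QValue, self.actionInput) zeroes out every entry except the one for the chosen action, so summing the product along axis 1 with tf.reduce_sum yields exactly the Q-value of that action.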
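The one-hot selection described above can be checked without TensorFlow. The following is a minimal NumPy sketch, not the repo's code: the helper names `q_for_taken_actions` and `squared_error_cost` and the sample numbers are made up for illustration, and the arithmetic mirrors the `tf.reduce_sum(tf.mul(...))` and `tf.reduce_mean(tf.square(...))` expressions.

```python
import numpy as np

def q_for_taken_actions(q_values, actions_one_hot):
    # Element-wise product zeroes the Q-values of actions that were not taken;
    # summing over the action axis leaves Q(s, a) for the chosen action.
    return np.sum(q_values * actions_one_hot, axis=1)

def squared_error_cost(y, q_action):
    # Mean squared error between the targets y and the selected Q-values.
    return np.mean(np.square(y - q_action))

# Illustrative batch of two states with two actions each (hypothetical values).
q_values = np.array([[1.0, 2.0],
                     [3.0, 0.5]])
actions_one_hot = np.array([[0.0, 1.0],   # action 1 taken in the first state
                            [1.0, 0.0]])  # action 0 taken in the second state
y = np.array([2.5, 3.0])

q_action = q_for_taken_actions(q_values, actions_one_hot)
cost = squared_error_cost(y, q_action)
```

Only the Q-value of the taken action contributes to the cost, so the gradient flows through a single output unit per sample.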
This is the key part of the code: computing y, that is, the target Q-value.
<code class="hljs python has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">def trainQNetwork(self):
    # Step 1: obtain random minibatch from replay memory
    minibatch = random.sample(self.replayMemory, self.BATCH_SIZE)
    state_batch = [data[0] for data in minibatch]
    action_batch = [data[1] for data in minibatch]
    reward_batch = [data[2] for data in minibatch]
    nextState_batch = [data[3] for data in minibatch]

    # Step 2: calculate y
    y_batch = []
    QValue_batch = self.QValue.eval(feed_dict={self.stateInput: nextState_batch})
    for i in range(0, self.BATCH_SIZE):
        terminal = minibatch[i][4]
        if terminal:
            # terminal transition: the target is just the reward
            y_batch.append(reward_batch[i])
        else:
            # GAMMA is a class attribute, so it must be read through self
            y_batch.append(reward_batch[i] + self.GAMMA * np.max(QValue_batch[i]))

    self.trainStep.run(feed_dict={
        self.yInput: y_batch,
        self.actionInput: action_batch,
        self.stateInput: state_batch
        })</code>
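As a hedged illustration of the target computation in trainQNetwork, here is a vectorized NumPy sketch of the same rule: y = r for terminal transitions, and y = r + GAMMA * max_a' Q(s', a') otherwise. The function name `compute_targets` and the sample rewards and Q-values are assumptions made for this example and are not part of the repo.

```python
import numpy as np

def compute_targets(rewards, next_q_values, terminals, gamma=0.99):
    """Return the DQN regression targets for a minibatch.

    rewards:       shape (batch,)
    next_q_values: shape (batch, n_actions), Q-values of the next states
    terminals:     shape (batch,), True where the episode ended
    """
    rewards = np.asarray(rewards, dtype=float)
    next_q_values = np.asarray(next_q_values, dtype=float)
    terminals = np.asarray(terminals, dtype=bool)
    # Discounted value of the best next action for every sample.
    bootstrap = gamma * np.max(next_q_values, axis=1)
    # Terminal transitions get no bootstrap term.
    return rewards + np.where(terminals, 0.0, bootstrap)

# Hypothetical minibatch: one ordinary step and one terminal step.
targets = compute_targets(
    rewards=[0.1, -1.0],
    next_q_values=[[1.0, 2.0], [5.0, 4.0]],
    terminals=[False, True],
)
```

The Python loop in trainQNetwork does the same thing sample by sample; the vectorized form just makes the two cases easier to see side by side.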
The remaining parts are straightforward, so here is the complete code:
<code class="hljs python has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"># -----------------------------
# File: Deep Q-Learning Algorithm
# Author: Flood Sung
# Date: 2016.3.21
# -----------------------------
# Note: written against the early TensorFlow API of the time
# (tf.mul, tf.initialize_all_variables, reduction_indices).

import tensorflow as tf
import numpy as np
import random
from collections import deque

class BrainDQN:

    # Hyper Parameters:
    ACTION = 2
    FRAME_PER_ACTION = 1
    GAMMA = 0.99 # discount factor for future rewards
    OBSERVE = 100000. # timesteps to observe before training
    EXPLORE = 150000. # frames over which to anneal epsilon
    FINAL_EPSILON = 0.0 # final value of epsilon
    INITIAL_EPSILON = 0.0 # starting value of epsilon
    REPLAY_MEMORY = 50000 # number of previous transitions to remember
    BATCH_SIZE = 32 # size of minibatch

    def __init__(self):
        # init replay memory
        self.replayMemory = deque()
        # init Q network
        self.createQNetwork()
        # init some parameters
        self.timeStep = 0
        self.epsilon = self.INITIAL_EPSILON

    def createQNetwork(self):
        # network weights
        W_conv1 = self.weight_variable([8,8,4,32])
        b_conv1 = self.bias_variable([32])

        W_conv2 = self.weight_variable([4,4,32,64])
        b_conv2 = self.bias_variable([64])

        W_conv3 = self.weight_variable([3,3,64,64])
        b_conv3 = self.bias_variable([64])

        W_fc1 = self.weight_variable([1600,512])
        b_fc1 = self.bias_variable([512])

        W_fc2 = self.weight_variable([512,self.ACTION])
        b_fc2 = self.bias_variable([self.ACTION])

        # input layer: a stack of four 80x80 frames
        self.stateInput = tf.placeholder("float",[None,80,80,4])

        # hidden layers (80x80 -> 20x20 -> 10x10 -> 5x5 -> 5x5x64 = 1600)
        h_conv1 = tf.nn.relu(self.conv2d(self.stateInput,W_conv1,4) + b_conv1)
        h_pool1 = self.max_pool_2x2(h_conv1)

        h_conv2 = tf.nn.relu(self.conv2d(h_pool1,W_conv2,2) + b_conv2)

        h_conv3 = tf.nn.relu(self.conv2d(h_conv2,W_conv3,1) + b_conv3)

        h_conv3_flat = tf.reshape(h_conv3,[-1,1600])
        h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat,W_fc1) + b_fc1)

        # Q Value layer
        self.QValue = tf.matmul(h_fc1,W_fc2) + b_fc2

        self.actionInput = tf.placeholder("float",[None,self.ACTION])
        self.yInput = tf.placeholder("float", [None])
        # one-hot mask selects the Q-value of the action that was taken
        Q_action = tf.reduce_sum(tf.mul(self.QValue, self.actionInput), reduction_indices = 1)
        self.cost = tf.reduce_mean(tf.square(self.yInput - Q_action))
        self.trainStep = tf.train.AdamOptimizer(1e-6).minimize(self.cost)

        # saving and loading networks
        # kept on self so that trainQNetwork can save checkpoints
        self.saver = tf.train.Saver()
        self.session = tf.InteractiveSession()
        self.session.run(tf.initialize_all_variables())
        checkpoint = tf.train.get_checkpoint_state("saved_networks")
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.session, checkpoint.model_checkpoint_path)
            print("Successfully loaded: " + checkpoint.model_checkpoint_path)
        else:
            print("Could not find old network weights")

    def trainQNetwork(self):
        # Step 1: obtain random minibatch from replay memory
        minibatch = random.sample(self.replayMemory,self.BATCH_SIZE)
        state_batch = [data[0] for data in minibatch]
        action_batch = [data[1] for data in minibatch]
        reward_batch = [data[2] for data in minibatch]
        nextState_batch = [data[3] for data in minibatch]

        # Step 2: calculate y
        y_batch = []
        QValue_batch = self.QValue.eval(feed_dict={self.stateInput:nextState_batch})
        for i in range(0,self.BATCH_SIZE):
            terminal = minibatch[i][4]
            if terminal:
                y_batch.append(reward_batch[i])
            else:
                y_batch.append(reward_batch[i] + self.GAMMA * np.max(QValue_batch[i]))

        self.trainStep.run(feed_dict={
            self.yInput : y_batch,
            self.actionInput : action_batch,
            self.stateInput : state_batch
            })

        # save network every 10000 iterations
        if self.timeStep % 10000 == 0:
            self.saver.save(self.session, 'saved_networks/' + 'network' + '-dqn', global_step = self.timeStep)

    def setPerception(self,nextObservation,action,reward,terminal):
        # slide the 4-frame window: drop the oldest frame, append the newest
        newState = np.append(self.currentState[:,:,1:],nextObservation,axis = 2)
        self.replayMemory.append((self.currentState,action,reward,newState,terminal))
        if len(self.replayMemory) > self.REPLAY_MEMORY:
            self.replayMemory.popleft()
        if self.timeStep > self.OBSERVE:
            # Train the network
            self.trainQNetwork()

        self.currentState = newState
        self.timeStep += 1

    def getAction(self):
        QValue = self.QValue.eval(feed_dict= {self.stateInput:[self.currentState]})[0]
        action = np.zeros(self.ACTION)
        action_index = 0
        if self.timeStep % self.FRAME_PER_ACTION == 0:
            if random.random() <= self.epsilon:
                # explore: pick a random action
                action_index = random.randrange(self.ACTION)
                action[action_index] = 1
            else:
                # exploit: pick the action with the highest Q-value
                action_index = np.argmax(QValue)
                action[action_index] = 1
        else:
            action[0] = 1 # do nothing

        # anneal epsilon linearly once the observation phase is over
        if self.epsilon > self.FINAL_EPSILON and self.timeStep > self.OBSERVE:
            self.epsilon -= (self.INITIAL_EPSILON - self.FINAL_EPSILON)/self.EXPLORE

        return action

    def setInitState(self,observation):
        # the initial state repeats the first frame four times
        self.currentState = np.stack((observation, observation, observation, observation), axis = 2)

    def weight_variable(self,shape):
        initial = tf.truncated_normal(shape, stddev = 0.01)
        return tf.Variable(initial)

    def bias_variable(self,shape):
        initial = tf.constant(0.01, shape = shape)
        return tf.Variable(initial)

    def conv2d(self,x, W, stride):
        return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")

    def max_pool_2x2(self,x):
        return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")
</code><ul 
class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; 
padding: 0px 5px;">158</li><li style="box-sizing: border-box; padding: 0px 5px;">159</li><li style="box-sizing: border-box; padding: 0px 5px;">160</li><li style="box-sizing: border-box; padding: 0px 5px;">161</li><li style="box-sizing: border-box; padding: 0px 5px;">162</li><li style="box-sizing: border-box; padding: 0px 5px;">163</li><li style="box-sizing: border-box; padding: 0px 5px;">164</li></ul>
In total, the program is only about 160 lines of code.
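Without deep learning, you would have to locate the bird in the image by hand, compute its trajectory, and then work out when to press the key; code for that would likely run to several thousand lines. Deep learning greatly reduces the amount of code that has to be written.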
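This article has analyzed DQN from the code perspective. Readers who want to apply DQN to other problems can use this code as a starting point for their own experiments.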