Notes on Running the Deep Q-Network (DQN) CartPole Source Code for Deep Reinforcement Learning (Pinard Version)

1. Runtime Environment

  • NVIDIA GTX 1070
  • Ubuntu 16.04 x64
  • CUDA 8.0.61
  • cuDNN 5.1
  • Python 3.4
  • TensorFlow 1.2.0
  • gym 0.12.0 (gym-0.12.0.tar.gz)
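
To compare a local setup against the list above, the installed package versions can be queried from Python. The following is a minimal sketch; the helper names installed_version and report are hypothetical, and it assumes each package exposes a __version__ attribute. It uses str.format rather than f-strings so that it also runs on Python 3.4.

```python
import importlib


def installed_version(module_name):
    """Return the module's __version__ string, "unknown" if the module
    has no such attribute, or "not installed" if the import fails."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return "not installed"
    return str(getattr(module, "__version__", "unknown"))


def report(module_names):
    """Map each module name to the version string found for it."""
    return {name: installed_version(name) for name in module_names}
```

Calling print(report(["tensorflow", "gym"])) then shows which versions are present; the CUDA and cuDNN versions are not visible from Python this way and have to be checked separately.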

2. Preparation

Download the source code (a single file) and save it in a folder of your choice. URL: https://github.com/ljpzzz/machinelearning/blob/master/reinforcement-learning/dqn.py


3. Running

In your Python environment, cd to the directory that contains the source file and run the following command:

python dqn.py

The run is shown in the animation below:
[Animation: the DQN agent playing CartPole]

The program prints the following output:

episode:  0 Evaluation Average Reward: 172.2
episode:  100 Evaluation Average Reward: 9.8
episode:  200 Evaluation Average Reward: 9.6
episode:  300 Evaluation Average Reward: 9.9
episode:  400 Evaluation Average Reward: 9.9
episode:  500 Evaluation Average Reward: 145.1
episode:  600 Evaluation Average Reward: 195.2
episode:  700 Evaluation Average Reward: 165.2
episode:  800 Evaluation Average Reward: 181.7
episode:  900 Evaluation Average Reward: 169.4
episode:  1000 Evaluation Average Reward: 158.5
episode:  1100 Evaluation Average Reward: 193.6
episode:  1200 Evaluation Average Reward: 146.1
episode:  1300 Evaluation Average Reward: 122.3
episode:  1400 Evaluation Average Reward: 116.6
episode:  1500 Evaluation Average Reward: 153.6
episode:  1600 Evaluation Average Reward: 122.3
episode:  1700 Evaluation Average Reward: 191.2
episode:  1800 Evaluation Average Reward: 190.8
episode:  1900 Evaluation Average Reward: 194.6
episode:  2000 Evaluation Average Reward: 187.4
episode:  2100 Evaluation Average Reward: 196.0
episode:  2200 Evaluation Average Reward: 171.7
episode:  2300 Evaluation Average Reward: 166.3
episode:  2400 Evaluation Average Reward: 159.6
episode:  2500 Evaluation Average Reward: 168.9
episode:  2600 Evaluation Average Reward: 160.8
episode:  2700 Evaluation Average Reward: 154.9
episode:  2800 Evaluation Average Reward: 172.1
episode:  2900 Evaluation Average Reward: 192.6

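Overall, the agent plays better as training goes on, although the improvement is not monotonic: the evaluation average reward drops to about 10 for episodes 100 to 400, recovers from episode 500 onward, and then fluctuates between roughly 117 and 196.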
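
One way to inspect the trend in the log above is to parse the output lines and summarize the rewards. The sketch below is not part of dqn.py; the names LINE_RE, parse_log and summarize are hypothetical, and the line format is taken from the output shown above. Only a few of the log lines are reproduced as sample input.

```python
import re

# Matches lines such as "episode:  500 Evaluation Average Reward: 145.1"
LINE_RE = re.compile(r"episode:\s*(\d+)\s+Evaluation Average Reward:\s*([\d.]+)")


def parse_log(text):
    """Return a list of (episode, average_reward) tuples found in text."""
    return [(int(m.group(1)), float(m.group(2))) for m in LINE_RE.finditer(text)]


def summarize(records):
    """Return the minimum, maximum and mean reward over all records."""
    rewards = [reward for _, reward in records]
    return {
        "min": min(rewards),
        "max": max(rewards),
        "mean": sum(rewards) / len(rewards),
    }


sample = """\
episode:  0 Evaluation Average Reward: 172.2
episode:  100 Evaluation Average Reward: 9.8
episode:  500 Evaluation Average Reward: 145.1
episode:  2900 Evaluation Average Reward: 192.6
"""

records = parse_log(sample)
print(records)
print(summarize(records))
```

Redirecting the program output to a file (for example python dqn.py > log.txt) and passing the file contents to parse_log gives the same summary for a full run.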

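For the author's own walkthrough of the source code, see: Reinforcement Learning (8): Approximate Representation of the Value Function and Deep Q-Learning (强化学习(八)价值函数的近似表示与Deep Q-Learning)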


4. Possible Problems and Solutions

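Problem: The call to random.sample on line 75 of the source file ddqn.py raises an error: TypeError: Population must be a sequence or set. For dicts, use list(d).

Solution: Go to line 75 and change it to the following, which resolves the problem: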

minibatch = random.sample(list(self.replay_buffer), BATCH_SIZE)
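
The fix can be tried in isolation with the sketch below. It assumes that the replay buffer is a collections.deque (an assumption, since the source file is not reproduced here); the buffer contents and the BATCH_SIZE value are made up for illustration. To my understanding, deque was only registered as a sequence type from Python 3.5 onward, which would explain why random.sample rejects it on Python 3.4 while accepting the list copy.

```python
import random
from collections import deque

BATCH_SIZE = 4  # illustrative value, not the one used in the source

# Assumed stand-in for self.replay_buffer: a bounded deque of transitions.
replay_buffer = deque(maxlen=100)
for step in range(10):
    # Placeholder (state, action, reward, next_state, done) tuples.
    replay_buffer.append((step, step % 2, 1.0, step + 1, False))

# Copying the deque into a list gives random.sample a plain sequence.
minibatch = random.sample(list(replay_buffer), BATCH_SIZE)
print(len(minibatch))
```

The list copy costs time proportional to the buffer size on every sampling call, which is usually negligible for a small buffer like the one in this sketch.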
