Deep Reinforcement Learning with DQN (with source code for training Flappy Bird using DQN)

1. The DQN algorithm

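For a detailed description of the DQN algorithm, see:
Deep Reinforcement Learning: DQN Explained
Deep Reinforcement Learning: DQN
Introduction to Deep Reinforcement Learning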
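
As a quick orientation before the code, the quantity a DQN is trained to predict is the one-step Q-learning target for each stored transition. The sketch below illustrates that target in plain Python; it is not taken from the posts listed above or from the BrainDQN files. The function name `td_target` and the default discount `gamma=0.99` are assumptions made for this example.

```python
def td_target(reward, next_q_values, terminal, gamma=0.99):
    """Return the one-step Q-learning target for a single transition.

    reward        -- immediate reward r
    next_q_values -- estimated Q(s', a) for every action a
    terminal      -- True if s' ended the episode
    gamma         -- discount factor (0.99 is an assumed value)
    """
    if terminal:
        # No future return is added after a terminal state.
        return reward
    # Otherwise bootstrap from the best estimated action value in s'.
    return reward + gamma * max(next_q_values)


if __name__ == "__main__":
    print(td_target(0.1, [0.5, 2.0], terminal=False))
    print(td_target(-1.0, [0.5, 2.0], terminal=True))
```

In Flappy Bird, a transition is terminal when the bird crashes, so the target for that transition is simply the crash reward.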

2. DQN source code

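The source code contains three main .py files.

BrainDQN_Nature.py and BrainDQN_NIPS.py mainly define the network structure, and FlappyBirdDQN.py is the training script. The source of FlappyBirdDQN.py is shown below and can be run directly.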

# -------------------------
# Project: Deep Q-Learning on Flappy Bird
# Author: Flood Sung
# Date: 2016.3.21
# -------------------------

import cv2
import sys
sys.path.append("game/")
import wrapped_flappy_bird as game
from BrainDQN_Nature import BrainDQN
import numpy as np

# preprocess raw image to an 80*80 binarized (0/255) gray image with shape (80,80,1)
def preprocess(observation):
	observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
	ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY)
	return np.reshape(observation,(80,80,1))

def playFlappyBird():
	# Step 1: init BrainDQN
	actions = 2
	brain = BrainDQN(actions)
	# Step 2: init Flappy Bird Game
	flappyBird = game.GameState()
	# Step 3: play game
	# Step 3.1: obtain init state
	action0 = np.array([1,0])  # do nothing
	observation0, reward0, terminal = flappyBird.frame_step(action0)
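	# same resize/grayscale/threshold as preprocess(), but the first frame stays 2-D (80,80), without the channel axis
	observation0 = cv2.cvtColor(cv2.resize(observation0, (80, 80)), cv2.COLOR_BGR2GRAY)
	ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY)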
	brain.setInitState(observation0)

	# Step 3.2: run the game
	while True:
		action = brain.getAction()
		nextObservation,reward,terminal = flappyBird.frame_step(action)
		nextObservation = preprocess(nextObservation)
		brain.setPerception(nextObservation,action,reward,terminal)

def main():
	playFlappyBird()

if __name__ == '__main__':
	main()
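
The training loop above only relies on three methods of the agent: `setInitState`, `getAction`, and `setPerception`. To make that contract concrete, here is a minimal stand-in that exposes the same three method names. It is a sketch, not the contents of BrainDQN_Nature.py: the class name `RandomBrain`, the four-frame history, the replay memory size, and the purely random policy are all assumptions chosen for illustration, and no network is trained.

```python
import random
from collections import deque


class RandomBrain:
    """Hypothetical stand-in with the three methods the training loop calls."""

    def __init__(self, actions, history=4, memory_size=50000):
        self.actions = actions
        # The most recent frames together form the current state (assumed: 4).
        self.frames = deque(maxlen=history)
        # Stored transitions; a real agent would sample minibatches from here.
        self.replay_memory = deque(maxlen=memory_size)

    def setInitState(self, observation):
        # Fill the whole history with copies of the first frame.
        for _ in range(self.frames.maxlen):
            self.frames.append(observation)

    def getAction(self):
        # One-hot action picked uniformly at random; a real DQN agent
        # would instead choose epsilon-greedily from its Q-values.
        action = [0] * self.actions
        action[random.randrange(self.actions)] = 1
        return action

    def setPerception(self, nextObservation, action, reward, terminal):
        state = tuple(self.frames)
        self.frames.append(nextObservation)  # the oldest frame is dropped
        next_state = tuple(self.frames)
        self.replay_memory.append((state, action, reward, next_state, terminal))


if __name__ == "__main__":
    brain = RandomBrain(2)
    brain.setInitState("frame0")
    chosen = brain.getAction()
    brain.setPerception("frame1", chosen, 0.1, False)
    print(chosen, len(brain.replay_memory))
```

Strings stand in for image frames here so that the example runs without OpenCV or the game module; swapping in such a stub for `BrainDQN` is one way to check the game loop before starting a long training run.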
		

3. Experimental results

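Training took a long time, roughly three hours, before the bird finally learned to fly.
The source code used in this post is available at https://download.csdn.net/download/qq_29462849/10776375 and can be run directly.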
