最近在学习了强化学习之后,在guithub上下载了一些使用不同强化学习方法的小项目,收获颇丰,于是想自己搭建一个gym环境1,直接使用项目中的一些方法去训练,希望能够加深自己的一些理解
游戏参考的github上的大佬写的代码,在这里放上源码连接:GitHub - GrayPlane123/BirdGame: 小鸟管道游戏,通过键盘或鼠标控制小鸟振翅,如果小鸟碰撞到管道或者飞到界面边缘则游戏结束
接下来是对源代码的改造,以此来搭建gym的环境。
源码是用pygame来实现游戏显示的,主要改造方面如下:
1.将键盘响应改为输入action参数判断,action维度有两个,即无操作和飞翔,分别用0和1表示;
2.去除不必要的一些游戏结束的文本显示;
3.将游戏逻辑和游戏渲染分开,游戏逻辑部分大部分放在step函数内,渲染放在render函数内,在这里我发现游戏源码是用pygame的colliderect函数实现的,因此没法将所有有关pygame的操作分离开step函数,我也懒得再另外实现,于是将此函数保留在step内。
if upRect.colliderect(self.Bird.birdRect) or downRect.colliderect(self.Bird.birdRect):
self.Bird.dead = True
reward = -1
self.Bird.status = 2
4.奖励机制上我做了些许改变,原来是每通过一个管道就加一分,这里我改为了如果掉出世界或碰到管道就会扣一分,另外我把最高分设置为50,这样可以避免验证周期过长。
在对环境进行简单搭建之后(仅仅是能用了,代码部分很乱),开始使用强化学习方法进行训练,为了方便,训练的代码大部分粘的github上的项目代码,自己只是稍微改了一下,使用的DDQN策略进行训练,价值函数使用三层全连接进行模拟,另外使用了经验回放。
在训练过程中,发现诸多问题,经过九牛二虎之力后,终于能够成功跑起来训练了,但是又发现了一个严重的问题,模型好像很难收敛,在这时我考虑到是不是state没有设置好,原本我是以小鸟的y轴坐标、管道的x坐标和两管道的y坐标这四维作为state,模型往往中间会有一段很好,但是后来分数又会下降,后来改为了小鸟离管道的距离、两个管道高度上分别离小鸟有多远这三维为state,这次效果虽然好点了,但是还是会出现分数高到一定程度后就又下降,不断反复这个过程,我不确定是我的state没有设置好还是其他原因,还在寻求解决方法中,后面会在state和奖励机制上下点功夫。
三次训练都迟迟不能收敛:
篇幅原因,只粘上搭建环境的代码,代码比较乱,如有不适,敬请见谅:
环境需要调用的类声明:
import pygame
import sys
import random
num = 0
score = 0
size = width, height = 400, 650
class Bird(object):
#初始化小鸟的状态
def __init__(self):
self.birdRect = pygame.Rect(65, 50, 40, 40)
self.birdStatus = [pygame.image.load("../pictures/0.png"), #初始状态
pygame.image.load("../pictures/1.png"), #飞行状态
pygame.image.load("../pictures/dead.png")] #失败状态
self.status = 0
self.birdX = 120
self.birdY = random.randint(10, height - 10)
self.jump = False
self.jumpSpeed = 10
self.gravity = 1
self.dead = False
#模拟小鸟飞行的动作
def birdUpdate(self):
#上升
if self.jump:
self.jumpSpeed -= 1
self.birdY -= self.jumpSpeed
#下降
else:
self.gravity += 0.1
self.birdY += self.gravity
self.birdRect[1] = self.birdY
class Pipeline(object):
def __init__(self):
self.wallx = 400
self.pineUp = pygame.image.load("../pictures/top.png")
self.pineDown = pygame.image.load("../pictures/bottom.png")
#小鸟飞行时管道相对向左移动,模拟管道移动
def updatePipeline(self):
a = self.pineDown.get_height()
b = self.pineUp.get_height()
self.wallx -= 5
score = 0
#小鸟飞越管道之后加一分
if self.wallx < -80:
score += 1
self.wallx = 400
#返回此次得分
return score
gym环境类:
import gym
from gym import spaces
import numpy as np
import pygame
from gym import spaces
import sys
import torch
import GameSet
class BirdGameEnv(gym.Env):
def __init__(self):
self.action_space = spaces.Discrete(2)
high = np.array([GameSet.height, -650, -650])
low = np.array([0, 650, 650])
self.observation_space = spaces.Box(low, high, dtype=np.int)
self.max_score = 50
self.score = 0
self.num = 0
self.done = False
self.clock = None
self.Pipeline = None
self.Bird = None
self.render_start = False
self.screen = None
self.background = None
def reset(self):
pygame.init()
self.score = 0
self.num = 0
self.done = False
self.render_start = False
self.Pipeline = GameSet.Pipeline()
self.Bird = GameSet.Bird()
updis = self.Bird.birdY + 450 - self.Pipeline.pineUp.get_height() - self.num
downdis = + 550 - self.num - self.Bird.birdY
return torch.tensor([self.Pipeline.wallx - self.Bird.birdX, updis, downdis], dtype=torch.float32)
def step(self, action):
# 用鼠标或键盘控制小鸟振翅
if action == 1:
self.Bird.jump = True
self.Bird.status = 1
self.Bird.gravity = 1
self.Bird.jumpSpeed = 10
self.Bird.birdUpdate()
# 上方管道
upRect = pygame.Rect(self.Pipeline.wallx, -450 + self.num,
self.Pipeline.pineUp.get_width(),
self.Pipeline.pineUp.get_height())
# 下方管道
downRect = pygame.Rect(self.Pipeline.wallx, 550 - self.num,
self.Pipeline.pineDown.get_width(),
self.Pipeline.pineDown.get_height())
# 移动管道
reward = self.Pipeline.updatePipeline()
if reward == 1:
self.num = np.random.randint(50,200)
# 检测小鸟与管道是否发生碰撞
if upRect.colliderect(self.Bird.birdRect) or downRect.colliderect(self.Bird.birdRect):
self.Bird.dead = True
reward = -1
self.Bird.status = 2
# 检测小鸟是否飞出界面区域(上下界)
if not 0 < self.Bird.birdRect[1] < GameSet.height:
self.Bird.dead = True
reward = -1
if self.Bird.dead == True or self.score >= self.max_score:
self.done = True
updis = self.Bird.birdY + 450 - self.Pipeline.pineUp.get_height() - self.num
downdis = 550 - self.num - self.Bird.birdY
self.score += reward
return torch.tensor([self.Pipeline.wallx - self.Bird.birdX, updis, downdis], dtype=torch.float32), reward, self.done, ''
def render(self):
if not self.render_start:
self.render_start = True
self.clock = pygame.time.Clock()
self.screen = pygame.display.set_mode((400, 650))
self.background = pygame.image.load("../pictures/background.png")
for event in pygame.event.get():
if event.type == pygame.QUIT:
sys.exit()
#每秒30次
self.clock.tick(30)
# 背景
self.screen.fill((255, 255, 255))
self.screen.blit(self.background, (0, 0))
# 更新小鸟的状态
self.screen.blit(self.Bird.birdStatus[self.Bird.status], (self.Bird.birdX, self.Bird.birdY))
# 管道
# self.screen.blit(Pipeline.pineUp, (Pipeline.wallx, -300 - randself.num)) # 上管道坐标位置
# self.screen.blit(Pipeline.pineDown, (Pipeline.wallx, 300 + randself.num)) # 下管道坐标位置
self.screen.blit(self.Pipeline.pineUp, (self.Pipeline.wallx, -450 + self.num)) # 上管道坐标位置
self.screen.blit(self.Pipeline.pineDown, (self.Pipeline.wallx, 550 - self.num)) # 下管道坐标位置
pygame.display.flip()
def close(self):
if self.screen is not None:
import pygame
pygame.display.quit()
pygame.quit()