【强化学习日志】小鸟管道游戏的gym环境搭建和DDQN训练

        最近在学习了强化学习之后,在guithub上下载了一些使用不同强化学习方法的小项目,收获颇丰,于是想自己搭建一个gym环境1,直接使用项目中的一些方法去训练,希望能够加深自己的一些理解

        游戏参考的github上的大佬写的代码,在这里放上源码连接:GitHub - GrayPlane123/BirdGame: 小鸟管道游戏,通过键盘或鼠标控制小鸟振翅,如果小鸟碰撞到管道或者飞到界面边缘则游戏结束

        接下来是对源代码的改造,以此来搭建gym的环境。

1. gym 环境搭建

         源码是用pygame来实现游戏显示的,主要改造方面如下:

        1.将键盘响应改为输入action参数判断,action维度有两个,即无操作和飞翔,分别用0和1表示;

        2.去除不必要的一些游戏结束的文本显示;

        3.将游戏逻辑和游戏渲染分开,游戏逻辑部分大部分放在step函数内,渲染放在render函数内,在这里我发现游戏源码是用pygame的colliderect函数实现的,因此没法将所有有关pygame的操作分离开step函数,我也懒得再另外实现,于是将此函数保留在step内。

        if upRect.colliderect(self.Bird.birdRect) or downRect.colliderect(self.Bird.birdRect):
            self.Bird.dead = True
            reward = -1
            self.Bird.status = 2

        4.奖励机制上我做了些许改变,原来是每通过一个管道就加一分,这里我改为了如果掉出世界或碰到管道就会扣一分,另外我把最高分设置为50,这样可以避免验证周期过长。

2. 训练阶段

        在对环境进行简单搭建之后(仅仅是能用了,代码部分很乱),开始使用强化学习方法进行训练,为了方便,训练的代码大部分粘的github上的项目代码,自己只是稍微改了一下,使用的DDQN策略进行训练,价值函数使用三层全连接进行模拟,另外使用了经验回放。

        在训练过程中,发现诸多问题,经过九牛二虎之力后,终于能够成功跑起来训练了,但是又发现了一个严重的问题,模型好像很难收敛,在这时我考虑到是不是state没有设置好,原本我是以小鸟的y轴坐标、管道的x坐标和两管道的y坐标这四维作为state,模型往往中间会有一段很好,但是后来分数又会下降,后来改为了小鸟离管道的距离、两个管道高度上分别离小鸟有多远这三维为state,这次效果虽然好点了,但是还是会出现分数高到一定程度后就又下降,不断反复这个过程,我不确定是我的state没有设置好还是其他原因,还在寻求解决方法中,后面会在state和奖励机制上下点功夫。

三次训练都迟迟不能收敛:

【强化学习日志】小鸟管道游戏的gym环境搭建和DDQN训练_第1张图片

3. 环境源码

        篇幅原因,只粘上搭建环境的代码,代码比较乱,如有不适,敬请见谅:

环境需要调用的类声明:

import pygame
import sys
import random

num = 0
score = 0
size = width, height = 400, 650

class Bird(object):

    #初始化小鸟的状态
    def __init__(self):

        self.birdRect = pygame.Rect(65, 50, 40, 40)
        self.birdStatus = [pygame.image.load("../pictures/0.png"),       #初始状态
                           pygame.image.load("../pictures/1.png"),       #飞行状态
                           pygame.image.load("../pictures/dead.png")]    #失败状态
        self.status = 0
        self.birdX = 120
        self.birdY = random.randint(10, height - 10)
        self.jump = False
        self.jumpSpeed = 10
        self.gravity = 1
        self.dead = False

    #模拟小鸟飞行的动作
    def birdUpdate(self):
        #上升
        if self.jump:
            self.jumpSpeed -= 1
            self.birdY -= self.jumpSpeed
        #下降
        else:
            self.gravity += 0.1
            self.birdY += self.gravity

        self.birdRect[1] = self.birdY


class Pipeline(object):

    def __init__(self):
        self.wallx = 400
        self.pineUp = pygame.image.load("../pictures/top.png")
        self.pineDown = pygame.image.load("../pictures/bottom.png")

    #小鸟飞行时管道相对向左移动,模拟管道移动
    def updatePipeline(self):

        a = self.pineDown.get_height()
        b = self.pineUp.get_height()


        self.wallx -= 5
        score = 0
        #小鸟飞越管道之后加一分
        if self.wallx < -80:
            score += 1
            self.wallx = 400
        #返回此次得分
        return score

gym环境类:

import gym
from gym import spaces
import numpy as np
import pygame
from gym import spaces
import sys
import torch

import GameSet

class BirdGameEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        high = np.array([GameSet.height, -650, -650])
        low = np.array([0, 650, 650])
        self.observation_space = spaces.Box(low, high, dtype=np.int)
        self.max_score = 50

        self.score = 0
        self.num = 0
        self.done = False

        self.clock = None
        self.Pipeline = None
        self.Bird = None
        self.render_start = False
        self.screen = None
        self.background = None
    def reset(self):
        pygame.init()
        self.score = 0
        self.num = 0
        self.done = False
        self.render_start = False


        self.Pipeline = GameSet.Pipeline()
        self.Bird = GameSet.Bird()
        updis =  self.Bird.birdY + 450 - self.Pipeline.pineUp.get_height() - self.num
        downdis =  + 550 - self.num - self.Bird.birdY
        return torch.tensor([self.Pipeline.wallx - self.Bird.birdX, updis, downdis], dtype=torch.float32)

    def step(self, action):
        # 用鼠标或键盘控制小鸟振翅
        if action == 1:
            self.Bird.jump = True
            self.Bird.status = 1
            self.Bird.gravity = 1
            self.Bird.jumpSpeed = 10

        self.Bird.birdUpdate()
        # 上方管道
        upRect = pygame.Rect(self.Pipeline.wallx, -450 + self.num,
                             self.Pipeline.pineUp.get_width(),
                             self.Pipeline.pineUp.get_height())

        # 下方管道
        downRect = pygame.Rect(self.Pipeline.wallx, 550 - self.num,
                               self.Pipeline.pineDown.get_width(),
                               self.Pipeline.pineDown.get_height())

        # 移动管道
        reward = self.Pipeline.updatePipeline()
        if reward == 1:
            self.num = np.random.randint(50,200)

        # 检测小鸟与管道是否发生碰撞
        if upRect.colliderect(self.Bird.birdRect) or downRect.colliderect(self.Bird.birdRect):
            self.Bird.dead = True
            reward = -1
            self.Bird.status = 2

        # 检测小鸟是否飞出界面区域(上下界)
        if not 0 < self.Bird.birdRect[1] < GameSet.height:
            self.Bird.dead = True
            reward = -1

        if self.Bird.dead == True or self.score >= self.max_score:
            self.done = True

        updis = self.Bird.birdY + 450 - self.Pipeline.pineUp.get_height() - self.num
        downdis = 550 - self.num - self.Bird.birdY
        self.score += reward
        return torch.tensor([self.Pipeline.wallx - self.Bird.birdX, updis, downdis], dtype=torch.float32), reward, self.done, ''

    def render(self):
        if not self.render_start:
            self.render_start = True
            self.clock = pygame.time.Clock()
            self.screen = pygame.display.set_mode((400, 650))
            self.background = pygame.image.load("../pictures/background.png")

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()


        #每秒30次
        self.clock.tick(30)
        # 背景
        self.screen.fill((255, 255, 255))
        self.screen.blit(self.background, (0, 0))

        # 更新小鸟的状态
        self.screen.blit(self.Bird.birdStatus[self.Bird.status], (self.Bird.birdX, self.Bird.birdY))
        # 管道
        # self.screen.blit(Pipeline.pineUp, (Pipeline.wallx, -300 - randself.num))   # 上管道坐标位置
        # self.screen.blit(Pipeline.pineDown, (Pipeline.wallx, 300 + randself.num))  # 下管道坐标位置
        self.screen.blit(self.Pipeline.pineUp, (self.Pipeline.wallx, -450 + self.num))  # 上管道坐标位置
        self.screen.blit(self.Pipeline.pineDown, (self.Pipeline.wallx, 550 - self.num))  # 下管道坐标位置

        pygame.display.flip()


    def close(self):
        if self.screen is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()

你可能感兴趣的:(python,人工智能,神经网络,机器学习)