Pytorch初步实现DQN玩贪吃蛇

Pytorch初步实现DQN玩贪吃蛇

  • 前言
  • 一.导入依赖库
    • 二.编写贪吃蛇游戏
    • 进一步处理返回的游戏图像
  • 三.一些重要的赋值
  • 四.定义记忆库
  • 五.定义强化学习网络(核心)
      • 1.定义一个卷积网络
    • 2.定义DQN网络
  • 六.最终实现

前言

本文部分代码参考了:孜然v的博客Python使用pygame编写贪吃蛇小游戏、mahuateng的博客[Deep Q Learning] pytorch 从零开始建立一个简单的DQN–走迷宫游戏、以及DQN Pytorch版官方文档

由于本人还只是一个普通大学的普通本科生,代码中涉及的算法原理现在无法解释,求甚解的小伙伴可以移步上述博客~

博客中存在的代码冗余部分以及不合理的部分欢迎大家指出!!

一.导入依赖库

import pygame   #编写游戏用的库
import sys
import math
from collections import namedtuple
import time
from pygame.locals import *
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt      
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

上述导入的都是常见的库,具体用途以及使用方法可以逐一百度

二.编写贪吃蛇游戏


class Tanchishe:
# 定义颜色变量
    def __init__(self):
        self.white_colour = pygame.Color(255, 144, 20)   #蛇蛇的颜色
        self.red_colour = pygame.Color(1,1 , 1)     #方块的颜色
        self.grey_colour = pygame.Color(255, 255, 255)  #背景颜色
        self.count = 0
        self.ftpsClock = pygame.time.Clock()
        # print(ftpsClock)
        # 创建一个窗口
        self.gamesurface = pygame.display.set_mode((260, 260)) 
        # 初始化贪吃蛇的起始位置
        self.snakeposition = [100, 100]
        # 初始化贪吃蛇的长度
        self.snakelength = [[100, 100], [80, 100], [60, 100]]
        # 初始化果实的位置
        self.square_purpose = [100, 200]
        # 初始化一个数来判断目标方块是否存在
        self.square_position = 1
        # 初始化方向,用来使贪吃蛇移动
        self.derection = "down"
        self.change_derection = self.derection
        self.endgame = 0
        pygame.display.flip()
        self.screen_image = pygame.surfarray.array3d(pygame.display.get_surface())
        # return screen_image
    def gameover(self):
        # 设置提示字体的格式
        self.__init__()  #重置游戏
        
        return self.screen_image
    # 定义主函数
    def main(self,action):
        count= 0
        done = 0
        # 数字对应的动作 [up=0,down=1,left=2,right=3]
        reward = 0
        pygame.init()
        pygame.time.Clock()
        self.pygame = pygame
        pygame.display.set_caption('贪吃蛇')
        
        # 进行游戏主循环
        while True:
           	#根据输入的动作指令来对self.change_derection赋值
            if action == 3 :
                self.change_derection = "right"
            if action == 2:
                self.change_derection = "left"
            if action == 0:
                self.change_derection = "up"
            if action == 1:
                self.change_derection = "down"
            
            #控制蛇蛇的运动方向
            if self.change_derection == 'left' and not self.derection == 'right':
                self.derection = self.change_derection
            if self.change_derection == 'right' and not self.derection == 'left':
                self.derection = self.change_derection
            if self.change_derection == 'up' and not self.derection == 'down':
                self.derection = self.change_derection
            if self.change_derection == 'down' and not self.derection == 'up':
                self.derection = self.change_derection
            # 根据方向,改变坐标
            if self.derection == 'left':
                self.snakeposition[0] -= 20
            if self.derection == 'right':
                self.snakeposition[0] += 20
            if self.derection == 'up':
                self.snakeposition[1] -= 20
            if self.derection == 'down':
                self.snakeposition[1] += 20
            # 增加蛇的长度
            self.snakelength.insert(0, list(self.snakeposition))
            # 判断是否吃掉果实
            if self.snakeposition[0] == self.square_purpose[0] and self.snakeposition[1] == self.square_purpose[1]:
                self.square_position = 0
                reward = 1     #如果吃掉果实,则获得1分
                self.count+=1
            else:
                reward=0.1   
                self.snakelength.pop()
            # 重新生成果实
            if self.square_position == 0:
                # 随机生成x,y,扩大二十倍,在窗口范围内
                x = random.randrange(1, 13)
                y = random.randrange(1, 13)
                self.square_purpose = [int(x * 20), int(y * 20)]
                self.square_position = 1
            # 绘制pygame显示层
            self.gamesurface.fill(self.grey_colour)
            for position in self.snakelength:
                pygame.draw.rect(self.gamesurface, self.white_colour, Rect(position[0], position[1], 20, 20))
                pygame.draw.rect(self.gamesurface, self.red_colour, Rect(self.square_purpose[0], self.square_purpose[1], 20, 20))
            
           #若蛇蛇超出屏幕范围,或者蛇蛇碰到自己,则游戏结束,扣掉1分
            if self.snakeposition[0] < 0 or self.snakeposition[0] > 260:
                
                reward = -1
                self.endgame = 1
            if self.snakeposition[1] < 0 or self.snakeposition[1] > 260:
                
                reward = -1
                self.endgame =1
            
            for snakebody in self.snakelength[1:]:
                if self.snakeposition[0] == snakebody[0] and self.snakeposition[1] == snakebody[1]:
                    
                    reward = -1
                    self.endgame =1 
               
            #获得游戏画面的3维numpy数组
            screen_image = pygame.surfarray.array3d(pygame.display.get_surface())
            # 控制游戏速度
            self.ftpsClock.tick(100000)
            pygame.display.update()
            
            return torch.from_numpy(np.array(reward)).reshape(1,1),screen_image,self.endgame,self.count

有以下几个要点:

· 首先是关于reward的定义

reward 分值
吃到果实 1
碰到自身或者超出屏幕 -1
其他 0.1

· 这段代码定义游戏屏幕大小为(260,260),定义每格大小为(20,20),这意味着蛇蛇的活动范围为13x13

· 函数的四个返回值缺一不可

进一步处理返回的游戏图像

这一步是为了将前面获得的screen_image由numpy转为tensor,返回screen_image时要使用.unsqueeze(0)使其从[3,40,40]升维到[1,3,40,40],以便于后续的操作。

def get_screen(screen):
    #这段来自于官方文档,np.ascontiguousarray()可以使传入的数组的内存连续
    screen = np.ascontiguousarray(screen, dtype=np.float32)  
    screen = torch.from_numpy(screen)
    # print(screen.shape)
    
    return resize(screen).unsqueeze(0).to(device)

三.一些重要的赋值

LR = 0.001                   # 学习率
EPSILON = 0.9               # 最优选择动作百分比(有0.9的几率是最大选择,还有0.1是随机选择,增加网络能学到的Q值)
GAMMA = 0.9                 # 奖励递减参数(衰减作用,如果没有奖励值r=0,则衰减Q值)
N_ACTIONS = 4  				#蛇蛇的动作0,1,2,3
'''后面几个变量的作用详见官方文档'''
BATCH_SIZE = 128
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200

steps_done = 0

四.定义记忆库

记忆库用来储存游戏进行时的四种信息:当前图像、动作、奖罚、下一个图像 (‘state’, ‘action’, ‘reward’,‘next_state’),供未来DQN网络训练

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward','next_state'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0
        self.counter = 0
    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity
        self.counter+=1
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

五.定义强化学习网络(核心)

1.定义一个卷积网络

由于我们得到的screen_image的形状为[1,3,40,40],则卷积神经网络负责通过卷积层
、全连接层最终得到4个输出值。即为蛇蛇的运动方向

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()        
        
        self.c1=nn.Conv2d(3,16,5,2,0)
        self.bn1 = nn.BatchNorm2d(16)
        
        self.c2=nn.Conv2d(16,32,5,2,0)
        self.bn2 = nn.BatchNorm2d(32)
        
        self.f1=nn.Linear(1568,100)
        self.f1.weight.data.normal_(0, 0.1)
        self.f2=nn.Linear(100,4)
        self.f2.weight.data.normal_(0, 0.1)
    def forward(self, x):
        #传入的x的形状:[1,3,40,40]
        x=F.relu(self.bn1(self.c1(x)))
        #x的形状[1, 16, 18, 18]
        x=F.relu(self.bn2(self.c2(x)))
        #x的形状[1,32,7,7]
        x=x.view(x.size(0),-1) 
        #x的形状[1,1568]
        x=self.f1(x)
        x=F.relu(x)   
        action=self.f2(x)
        return action

2.定义DQN网络

class DQN(object):
    def __init__(self):
        #DQN需要使用两个神经网络  
        #eval为Q估计神经网络 target为Q现实神经网络
        self.eval_net, self.target_net = Net().to(device), Net().to(device) 
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR) # torch 的优化器
        self.loss_func = nn.MSELoss()   # 误差公式
    def choose_action(self, x):
        
        x = torch.FloatTensor(x).to(device)
        # 这里只输入一个 sample,x为场景
        if np.random.uniform() < EPSILON:   # 选最优动作
            actions_value = self.eval_net.forward(x) #将场景输入Q估计神经网络
            #torch.max(input,dim)返回dim最大值并且在第二个位置返回位置比如(tensor([0.6507]), tensor([2]))
            action = torch.max(actions_value, 1)[1].data # 返回动作最大值
        else:   # 选随机动作
            action = torch.from_numpy(np.array([np.random.randint(0, N_ACTIONS)])) # 比如np.random.randint(0,2)是选择1或0
        
        # print(action)
        return action.to(device)
    
    def learn(self):    #训练卷积神经网络的输出值,试蛇蛇能找到果实的位置
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE) 
        batch = Transition(*zip(*transitions))
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                              batch.next_state)), device=device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                                    if s is not None])
        b_s = torch.cat(batch.state).to(device)
        # print(b_s.shape)
        b_a = torch.cat(batch.action).reshape([128,1]).to(device)
        b_r = torch.cat(batch.reward).to(device)
        b_s_ = torch.cat(batch.next_state).to(device)
        # 针对做过的动作b_a, 来选 q_eval 的值, (q_eval 原本有所有动作的值)
        # start_time = time.time()
        q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1) 找到action的Q估计(关于gather使用下面有介绍)
        q_next = self.target_net(b_s_).detach()     # q_next 不进行反向传递误差, 所以 detach Q现实
        q_target = b_r + GAMMA * q_next.max(1)[0]   # shape (batch, 1) DQL核心公式
        loss = self.loss_func(q_eval, q_target.float()) #计算误差
        # 计算, 更新 eval net
        self.optimizer.zero_grad() #
        loss.backward() #反向传递
        self.optimizer.step()
        # end_time = time.time()
        # print("耗时",end_time-start_time)

这里就要说到之前的升维操作了,如果不升维,图像的形状为[channel,x,y],升维后为[batch,channel,x,y]
如不升维,执行b_s = torch.cat(batch.state).to(device)这一步时会报错,print发现图像全部叠加到了channel那一维度,我们希望叠加到batch那一维度,所以升维操作必不可少。

六.最终实现

if __name__ == "__main__":
    #将模型移入GPU训练
    device = torch.device("cuda:0") 
    false=0
    succes = 0
    env = Tanchishe()
    study=1
    count=0
    #初始化网络
    dqn = DQN()
    #初始化记忆库
    memory = ReplayMemory(2000)
    for i in range(100000):
       #不知道为什么游戏函数返回的是[y,x,channel]的格式
       #tranpose负责将y轴与x轴对调 
       s = env.gameover().transpose(1,0,2)
       s = get_screen(s.transpose(2,0,1))
       print("现在是第 {} 代,蛇蛇吃下了 {} 个方块".format(i,count))
       while True:
            #选择动作
            a = dqn.choose_action(s.cpu())
            # print(a)
            #将动作传入游戏中,控制蛇蛇移动
            reward,screen_image,done,count = env.main(a)
            # print(reward)
            # print(screen_image.shape.transpose(1,0,2))
            screen_image = screen_image.transpose(1,0,2)
            screen_image = get_screen(screen_image.transpose(2,0,1))
            # plt.imshow(screen_image.numpy().transpose(1,2,0))
            # plt.show()
            #将信息传入记忆库
            memory.push(s, a, reward, screen_image)
            #如果记忆库信息数量超过记忆库大小
            if memory.counter > memory.capacity:
                dqn.learn()        
            if done==1 or done==2:    # 如果回合结束, 进入下回合
                
                if done==1:
                    # print('epoch',i_episode,r,'失败')
                    
                    false+=1
                if done==2:
                    
                    succes+=1
                    # print('epoch',i_episode,r,'成功')                
                break
            s = screen_image
        

P.s. Python不愧是极度精简的语言,不管是感官还是视觉都是一种享受~

你可能感兴趣的:(深度学习,pytorch,神经网络)