Parts of the code in this post are based on: 孜然v's blog post "Python使用pygame编写贪吃蛇小游戏" (a Snake mini-game written with pygame), mahuateng's blog post "[Deep Q Learning] pytorch 从零开始建立一个简单的DQN–走迷宫游戏" (building a simple DQN from scratch to play a maze game), and the official PyTorch DQN tutorial.
Since I am just an ordinary undergraduate at an ordinary university, I cannot yet explain the theory behind the algorithms used in the code; readers who want the details can head over to the blogs above~
If you spot any redundant or unreasonable parts of the code in this post, please do point them out!!
import pygame  # library used to build the game
import sys
import math
from collections import namedtuple
import time
from pygame.locals import *
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
All of the imports above are common libraries; you can look up what each one does and how to use it.
class Tanchishe:
    # define the colour variables
    def __init__(self):
        pygame.init()  # safe to call repeatedly; makes sure the display module is ready before set_mode
        self.white_colour = pygame.Color(255, 144, 20)  # colour of the snake (orange, despite the variable name)
        self.red_colour = pygame.Color(1, 1, 1)         # colour of the food block (nearly black, despite the name)
        self.grey_colour = pygame.Color(255, 255, 255)  # background colour (white)
        self.count = 0
        self.ftpsClock = pygame.time.Clock()
        # print(ftpsClock)
        # create a window
        self.gamesurface = pygame.display.set_mode((260, 260))
        # initial position of the snake's head
        self.snakeposition = [100, 100]
        # initial body segments of the snake
        self.snakelength = [[100, 100], [80, 100], [60, 100]]
        # initial position of the food block
        self.square_purpose = [100, 200]
        # flag that records whether the food block currently exists
        self.square_position = 1
        # movement direction, used to drive the snake
        self.derection = "down"
        self.change_derection = self.derection
        self.endgame = 0
        pygame.display.flip()
        self.screen_image = pygame.surfarray.array3d(pygame.display.get_surface())
        # return screen_image
    def gameover(self):
        # reset the game to its initial state and hand back the fresh screen
        self.__init__()
        return self.screen_image
    # main step function
    def main(self, action):
        count = 0
        done = 0
        # action indices: [up=0, down=1, left=2, right=3]
        reward = 0
        pygame.init()
        pygame.time.Clock()
        self.pygame = pygame
        pygame.display.set_caption('贪吃蛇')
        # main game loop
        while True:
            # turn the incoming action index into a value for self.change_derection
            if action == 3:
                self.change_derection = "right"
            if action == 2:
                self.change_derection = "left"
            if action == 0:
                self.change_derection = "up"
            if action == 1:
                self.change_derection = "down"
            # the snake may turn, but never reverse straight onto itself
            if self.change_derection == 'left' and not self.derection == 'right':
                self.derection = self.change_derection
            if self.change_derection == 'right' and not self.derection == 'left':
                self.derection = self.change_derection
            if self.change_derection == 'up' and not self.derection == 'down':
                self.derection = self.change_derection
            if self.change_derection == 'down' and not self.derection == 'up':
                self.derection = self.change_derection
            # move the head one cell in the current direction
            if self.derection == 'left':
                self.snakeposition[0] -= 20
            if self.derection == 'right':
                self.snakeposition[0] += 20
            if self.derection == 'up':
                self.snakeposition[1] -= 20
            if self.derection == 'down':
                self.snakeposition[1] += 20
            # grow the snake by inserting the new head position
            self.snakelength.insert(0, list(self.snakeposition))
            # check whether the food block has been eaten
            if self.snakeposition[0] == self.square_purpose[0] and self.snakeposition[1] == self.square_purpose[1]:
                self.square_position = 0
                reward = 1  # eating the food block is worth 1 point
                self.count += 1
            else:
                reward = 0.1
                self.snakelength.pop()
            # regenerate the food block
            if self.square_position == 0:
                # pick a random grid cell and scale it by 20 so it stays inside the window
                x = random.randrange(1, 13)
                y = random.randrange(1, 13)
                self.square_purpose = [int(x * 20), int(y * 20)]
                self.square_position = 1
            # draw the pygame display layer
            self.gamesurface.fill(self.grey_colour)
            for position in self.snakelength:
                pygame.draw.rect(self.gamesurface, self.white_colour, Rect(position[0], position[1], 20, 20))
            pygame.draw.rect(self.gamesurface, self.red_colour, Rect(self.square_purpose[0], self.square_purpose[1], 20, 20))
            # the game ends (and 1 point is deducted) if the snake leaves the screen or runs into itself
            if self.snakeposition[0] < 0 or self.snakeposition[0] >= 260:  # head positions run 0..240, so 260 is already off screen
                reward = -1
                self.endgame = 1
            if self.snakeposition[1] < 0 or self.snakeposition[1] >= 260:
                reward = -1
                self.endgame = 1
            for snakebody in self.snakelength[1:]:
                if self.snakeposition[0] == snakebody[0] and self.snakeposition[1] == snakebody[1]:
                    reward = -1
                    self.endgame = 1
            # grab the current frame as a 3-D numpy array
            screen_image = pygame.surfarray.array3d(pygame.display.get_surface())
            # control the game speed
            self.ftpsClock.tick(100000)
            pygame.display.update()
            # float32 keeps every reward tensor the same dtype, so torch.cat in learn() behaves consistently
            return torch.from_numpy(np.array(reward, dtype=np.float32)).reshape(1, 1), screen_image, self.endgame, self.count
A few points worth noting:
· First, the definition of the reward:

event | reward |
---|---|
eats the food block | 1 |
runs into its own body or leaves the screen | -1 |
any other step | 0.1 |

· The code sets the game window to (260, 260) and each cell to (20, 20), so the snake moves on a 13x13 grid.
· All four of the function's return values are needed: the reward, the screen image, the done flag and the score.
This step converts the screen_image obtained above from a numpy array into a tensor. When returning screen_image, .unsqueeze(0) is applied to lift it from [3, 40, 40] to [1, 3, 40, 40], which makes the later batching operations possible.
def get_screen(screen):
    # taken from the official tutorial: np.ascontiguousarray() makes the incoming array contiguous in memory
    screen = np.ascontiguousarray(screen, dtype=np.float32)
    screen = torch.from_numpy(screen)
    # print(screen.shape)
    return resize(screen).unsqueeze(0).to(device)
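Note that the definition of resize is not part of this snippet. A minimal sketch of what it has to do, assuming it produces the [3, 40, 40] frames quoted above, would be:

```python
# Assumed definition (not shown in this excerpt): shrink a [3, H, W] frame to [3, 40, 40].
# Recent torchvision versions can apply Resize directly to a tensor; the official tutorial instead
# chains T.ToPILImage() / T.Resize(40) / T.ToTensor() on an array already scaled to [0, 1].
resize = T.Resize((40, 40))
```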
LR = 0.001       # learning rate
EPSILON = 0.9    # greedy rate: with probability 0.9 take the best action, with 0.1 act randomly so the network gets to see more Q-values
GAMMA = 0.9      # reward discount factor (future rewards are decayed; with no reward, r = 0, the Q-value simply decays)
N_ACTIONS = 4    # the snake's actions 0, 1, 2, 3
'''The remaining constants come from the official tutorial'''
BATCH_SIZE = 128
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
steps_done = 0
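For completeness, this is roughly how EPS_START, EPS_END, EPS_DECAY and steps_done fit together in the official tutorial. The choose_action method below sticks with the fixed EPSILON, so in this post these constants are effectively unused:

```python
# Sketch of the official tutorial's decaying exploration rate (not used by the code below):
# epsilon starts at EPS_START and decays exponentially towards EPS_END as steps_done grows.
def eps_threshold(steps_done):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
```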
The replay memory stores four pieces of information from every game step: the current image, the action, the reward and the next image ('state', 'action', 'reward', 'next_state'), to be sampled later when training the DQN.
Transition = namedtuple('Transition',
('state', 'action', 'reward','next_state'))
class ReplayMemory(object):
def __init__(self, capacity):
self.capacity = capacity
self.memory = []
self.position = 0
self.counter = 0
def push(self, *args):
"""Saves a transition."""
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args)
self.position = (self.position + 1) % self.capacity
self.counter+=1
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
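A quick toy run shows how the memory behaves (strings and plain numbers stand in for the real image tensors):

```python
# Toy usage of ReplayMemory; the real code pushes screen tensors, action tensors and reward tensors
mem = ReplayMemory(3)
mem.push("s0", 0, 1.0, "s1")
mem.push("s1", 2, 0.1, "s2")
print(len(mem))         # 2
batch = mem.sample(2)   # a list of 2 Transition namedtuples
print(batch[0].reward)  # the reward stored in the first sampled transition
```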
Since the screen_image we obtain has shape [1, 3, 40, 40], the convolutional network pushes it through the convolution layers and the fully connected layers to produce 4 output values, one per movement direction of the snake.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.c1=nn.Conv2d(3,16,5,2,0)
self.bn1 = nn.BatchNorm2d(16)
self.c2=nn.Conv2d(16,32,5,2,0)
self.bn2 = nn.BatchNorm2d(32)
self.f1=nn.Linear(1568,100)
self.f1.weight.data.normal_(0, 0.1)
self.f2=nn.Linear(100,4)
self.f2.weight.data.normal_(0, 0.1)
    def forward(self, x):
        # incoming x has shape [1, 3, 40, 40]
        x = F.relu(self.bn1(self.c1(x)))
        # x now has shape [1, 16, 18, 18]
        x = F.relu(self.bn2(self.c2(x)))
        # x now has shape [1, 32, 7, 7]
        x = x.view(x.size(0), -1)
        # x now has shape [1, 1568]
        x = self.f1(x)
        x = F.relu(x)
        action = self.f2(x)
        return action
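Where does the 1568 in self.f1 come from? With the usual convolution output formula it follows directly from the 40x40 input:

```python
# Conv output size with padding 0: floor((size - kernel_size) / stride) + 1
# conv1: (40 - 5) // 2 + 1 = 18  -> [1, 16, 18, 18]
# conv2: (18 - 5) // 2 + 1 = 7   -> [1, 32, 7, 7]
# flatten: 32 * 7 * 7 = 1568     -> matches nn.Linear(1568, 100)
def conv2d_out(size, kernel_size=5, stride=2):
    return (size - kernel_size) // stride + 1

print(32 * conv2d_out(conv2d_out(40)) ** 2)  # 1568
```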
class DQN(object):
    def __init__(self):
        # DQN uses two neural networks:
        # eval_net estimates Q-values, target_net supplies the Q targets (note: this version never syncs target_net from eval_net)
        self.eval_net, self.target_net = Net().to(device), Net().to(device)
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)  # torch optimizer
        self.loss_func = nn.MSELoss()  # loss function
    def choose_action(self, x):
        x = torch.FloatTensor(x).to(device)
        # only a single sample is passed in; x is the current screen
        if np.random.uniform() < EPSILON:  # take the greedy action
            actions_value = self.eval_net.forward(x)  # feed the screen through the evaluation network
            # torch.max(input, dim) returns (values, indices), e.g. (tensor([0.6507]), tensor([2]))
            action = torch.max(actions_value, 1)[1].data  # index of the largest Q-value, i.e. the chosen action
        else:  # take a random action
            action = torch.from_numpy(np.array([np.random.randint(0, N_ACTIONS)]))  # e.g. np.random.randint(0, 2) picks 0 or 1
        # print(action)
        return action.to(device)
    def learn(self):  # train the network so that the snake learns to find the food block
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))
        # kept from the official tutorial; in this code next_state is never None, so these two variables go unused
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), device=device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        b_s = torch.cat(batch.state).to(device)
        # print(b_s.shape)
        b_a = torch.cat(batch.action).reshape([BATCH_SIZE, 1]).to(device)
        b_r = torch.cat(batch.reward).to(device)
        b_s_ = torch.cat(batch.next_state).to(device)
        # pick the Q-estimate of the action that was actually taken (q_eval originally holds the values of all actions)
        # start_time = time.time()
        q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1); gather is illustrated below
        q_next = self.target_net(b_s_).detach()  # detach so that no gradients flow back through the target network
        # core DQN update: target = r + GAMMA * max_a' Q_target(s', a');
        # the reshape keeps q_target at shape (batch, 1) so it lines up with q_eval instead of broadcasting to (batch, batch)
        q_target = b_r + GAMMA * q_next.max(1)[0].reshape(BATCH_SIZE, 1)
        loss = self.loss_func(q_eval, q_target.float())  # compute the loss
        # update the eval net
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # end_time = time.time()
        # print("elapsed time", end_time - start_time)
This is where the earlier dimension-raising step matters. Without it an image has shape [channel, x, y]; after unsqueeze it is [batch, channel, x, y].
Without the extra dimension, running b_s = torch.cat(batch.state).to(device) leads to an error further on, and a print shows that the images were all stacked onto the channel dimension; we want them stacked along the batch dimension, so the unsqueeze step is indispensable.
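A small demonstration of that point with random tensors of the same shapes:

```python
# Without the batch dimension, cat stacks frames along the channel axis:
frames = [torch.rand(3, 40, 40) for _ in range(4)]
print(torch.cat(frames).shape)   # torch.Size([12, 40, 40]) -- not what the network expects
# With unsqueeze(0) every frame is [1, 3, 40, 40] and cat stacks along the batch axis:
frames = [torch.rand(3, 40, 40).unsqueeze(0) for _ in range(4)]
print(torch.cat(frames).shape)   # torch.Size([4, 3, 40, 40])
```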
if __name__ == "__main__":
#将模型移入GPU训练
device = torch.device("cuda:0")
false=0
succes = 0
env = Tanchishe()
study=1
count=0
#初始化网络
dqn = DQN()
#初始化记忆库
memory = ReplayMemory(2000)
for i in range(100000):
#不知道为什么游戏函数返回的是[y,x,channel]的格式
#tranpose负责将y轴与x轴对调
s = env.gameover().transpose(1,0,2)
s = get_screen(s.transpose(2,0,1))
print("现在是第 {} 代,蛇蛇吃下了 {} 个方块".format(i,count))
while True:
#选择动作
a = dqn.choose_action(s.cpu())
# print(a)
#将动作传入游戏中,控制蛇蛇移动
reward,screen_image,done,count = env.main(a)
# print(reward)
# print(screen_image.shape.transpose(1,0,2))
screen_image = screen_image.transpose(1,0,2)
screen_image = get_screen(screen_image.transpose(2,0,1))
# plt.imshow(screen_image.numpy().transpose(1,2,0))
# plt.show()
#将信息传入记忆库
memory.push(s, a, reward, screen_image)
#如果记忆库信息数量超过记忆库大小
if memory.counter > memory.capacity:
dqn.learn()
if done==1 or done==2: # 如果回合结束, 进入下回合
if done==1:
# print('epoch',i_episode,r,'失败')
false+=1
if done==2:
succes+=1
# print('epoch',i_episode,r,'成功')
break
s = screen_image
P.s. Python really lives up to its reputation as an extremely concise language; both writing it and looking at it are a pleasure~