本文隶属于一个完整小项目,建议读者按照顺序阅读。
本文仅仅展示最关键的代码部分,并不会列举所有代码细节,相信具备RL基础的同学理解起来没有困难。
全部的AI代码可以在【Python小游戏】用AI玩Python小游戏FlappyBird【源码】中找到开源地址。
如果本文对您有帮助,欢迎点赞支持!
前言
第1种设计方式:基于Numpy数组
第2种设计方式:基于Python数组
第3种设计方式:基于队列
书写经验重放池是Deep Rl算法的必备技术之一,常见的是基于数组的形式,本文列举3种常见的实现方式。
本文不会详细介绍代码,因为太过简单,不理解的同学可以直接在评论区提问。
class ReplayBuffer(object):
def __init__(self, capacity,state_dims):
self.capacity = capacity # 经验池容量大小
self.data = np.zeros((capacity, state_dims* 2+2)) # 经验池存放的经验数据
self.pointer = 0 # 当前指针
def store_transition(self, s, a, r, s_):
# 检查是否存在
if not hasattr(self, 'pointer'):
self.pointer = 0
# 存储数据
transition = np.hstack((s, [a,r], s_)) # 按行连接
index = self.pointer % self.capacity # 如果超过该容量则自动从头开始
self.data[index, :] = transition
self.pointer += 1
def sample(self, batch_size):
if self.capacity < self.pointer:
batch_indexs = np.random.choice(self.capacity, size=batch_size)
else:
batch_indexs = np.random.choice(self.pointer, size=batch_size)
#assert (self.pointer >= self.capacity, '经验回放池还没有被装满')
#print('经验回放池还没有被装满就开始采样')
return self.data[batch_indexs, :] # 获取n个采样
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = int((self.position + 1) % self.capacity) # as a ring buffer
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
state, action, reward, next_state, done = map(np.stack, zip(*batch)) # stack for each element
return state, action, reward, next_state, done
def __len__(self):
return len(self.buffer)
本项目使用队列来进行设计,其代码更加简洁:
from collections import deque
import random
class ReplayBuffer(object):
def __init__(self, capacity):
self.memory_size = capacity # 容量大小
self.num = 0 # 存放的经验数据数量
self.data = deque() # 存放经验数据的队列
def store_transition(self, state,action,reward,state_,terminal):
self.data.append((state, action, reward, state_, terminal))# 添加数据
if len(self.data) > self.memory_size:
self.data.popleft()
self.num -= 1
self.num += 1
def sample(self, batch_size):
minibatch = random.sample(self.data, batch_size)
return minibatch # 获取n个采样