【FlappyBird小游戏】编写AI逻辑(二)——基于队列的经验重放池

本文隶属于一个完整小项目,建议读者按照顺序阅读。

本文仅仅展示最关键的代码部分,并不会列举所有代码细节,相信具备RL基础的同学理解起来没有困难。

全部的AI代码可以在【Python小游戏】用AI玩Python小游戏FlappyBird【源码】中找到开源地址。

如果本文对您有帮助,欢迎点赞支持!


文章目录

前言

第1种设计方式:基于Numpy数组

第2种设计方式:基于Python数组

第3种设计方式:基于队列


前言

书写经验重放池是Deep Rl算法的必备技术之一,常见的是基于数组的形式,本文列举3种常见的实现方式

本文不会详细介绍代码,因为太过简单,不理解的同学可以直接在评论区提问。


第1种设计方式:基于Numpy数组

class ReplayBuffer(object):
    def __init__(self, capacity,state_dims):
        self.capacity = capacity # 经验池容量大小
        self.data = np.zeros((capacity, state_dims* 2+2))  # 经验池存放的经验数据
        self.pointer = 0 # 当前指针
    def store_transition(self, s, a, r, s_):
        # 检查是否存在
        if not hasattr(self, 'pointer'):
            self.pointer = 0
        # 存储数据
        transition = np.hstack((s, [a,r], s_))  # 按行连接
        index = self.pointer % self.capacity  # 如果超过该容量则自动从头开始
        self.data[index, :] = transition
        self.pointer += 1
    def sample(self, batch_size):
        if self.capacity < self.pointer:
            batch_indexs = np.random.choice(self.capacity, size=batch_size)
        else:
            batch_indexs = np.random.choice(self.pointer, size=batch_size)
            #assert (self.pointer >= self.capacity, '经验回放池还没有被装满')
            #print('经验回放池还没有被装满就开始采样')
        return self.data[batch_indexs, :]  # 获取n个采样
    

第2种设计方式:基于Python数组

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = int((self.position + 1) % self.capacity)  # as a ring buffer

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))  # stack for each element
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

第3种设计方式:基于队列

本项目使用队列来进行设计,其代码更加简洁:

from collections import deque
import random
class ReplayBuffer(object):
    def __init__(self, capacity):
        self.memory_size = capacity # 容量大小
        self.num = 0 # 存放的经验数据数量
        self.data = deque() # 存放经验数据的队列

    def store_transition(self, state,action,reward,state_,terminal):
        self.data.append((state, action, reward, state_, terminal))# 添加数据
        if len(self.data) > self.memory_size:
            self.data.popleft()
            self.num -= 1
        self.num += 1

    def sample(self, batch_size):
        minibatch = random.sample(self.data, batch_size)
        return minibatch  # 获取n个采样

 

你可能感兴趣的:(Python程序设计,机器学习之强化学习,队列,算法,python,强化学习,游戏)