自定义并注册gym环境

目录

  • 需要遵循的架构
    • setup.py
    • gym_maze/__init__.py
    • gym_maze/envs/__init__.py
    • gym_maze/envs/maze_env.py
  • 安装
  • 使用
  • 参考

需要遵循的架构

  • gym_maze/
    • setup.py
    • gym_maze/
      • __init__.py
      • envs/
        • __init__.py
        • maze_env.py

下面介绍各部分需要写的内容.

setup.py

from setuptools import find_packages, setup

# Distribution metadata for the maze environment.
# `packages=find_packages()` explicitly includes the inner `gym_maze` package
# (and `gym_maze.envs`), so the code is shipped even when setuptools'
# automatic package discovery is unavailable (setuptools < 61) or when the
# project is installed with a plain `pip install .` instead of `pip install -e .`.
setup(
    name='gym_maze',
    version='0.0.1',
    packages=find_packages(),
    install_requires=['gym'],
)

这里主要定义包的名字(name)、版本号(version)以及所依赖的其他包(install_requires)。

gym_maze/__init__.py

from gym.envs.registration import register

# Make the maze available to gym.make():
#   id          -- the environment name passed to gym.make()
#   entry_point -- "<importable module>:<Env class>" locating the environment class
_MAZE_ENV_SPEC = {
    'id': 'maze-v0',
    'entry_point': 'gym_maze.envs:Maze',
}
register(**_MAZE_ENV_SPEC)

id 即调用 gym.make() 时使用的环境名字; entry_point 指明环境类来自哪里, 格式为 `模块路径:类名`, 这里表示 gym_maze.envs 模块中的 Maze 类。

gym_maze/envs/__init__.py

from gym_maze.envs.maze_env import Maze

gym_maze/envs/maze_env.py

import gym
import random
import time
from gym.envs.classic_control import rendering


class Maze(gym.Env):
    """5x5 grid-world maze with obstacles and a single goal cell.

    States are the integers 1..25, numbered row by row from the top-left
    cell (1) to the bottom-right cell (25).  Actions are the compass strings
    'n', 'e', 's', 'w'.  A move that is not listed in the transition table
    (into a wall, into an obstacle, or off the grid) leaves the agent where
    it is.  Every step costs -1, except the three moves that enter the goal
    cell (state 15), which pay +10 and end the episode.

    Uses the classic gym API: ``reset()`` returns the start state and
    ``step()`` returns ``(next_state, reward, done, info)``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        # States 1..25, laid out row by row on a 5x5 grid.
        self.states = [1, 2, 3, 4, 5,
                       6, 7, 8, 9, 10,
                       11, 12, 13, 14, 15,
                       16, 17, 18, 19, 20,
                       21, 22, 23, 24, 25]

        # Blocked cells: never used as a start state and drawn black by render().
        self.obstacles = [4, 9, 11, 12, 23, 24, 25]

        self.actions = ['n', 'e', 's', 'w']

        # Rewards keyed by "<state>_<action>"; any other key yields -1 in step().
        self.rewards = dict()
        self.rewards['10_s'] = 10.0
        self.rewards['14_e'] = 10.0
        self.rewards['20_n'] = 10.0

        # Deterministic transitions keyed by "<state>_<action>".
        # A key that is absent means the move is blocked and the agent stays put.
        self.t = dict()
        self.t['1_s'] = 6
        self.t['1_e'] = 2
        self.t['2_w'] = 1
        self.t['2_s'] = 7
        self.t['2_e'] = 3
        self.t['3_w'] = 2
        self.t['3_s'] = 8
        self.t['5_s'] = 10
        self.t['6_n'] = 1
        self.t['6_e'] = 7
        self.t['7_w'] = 6
        self.t['7_n'] = 2
        self.t['7_e'] = 8
        self.t['8_w'] = 7
        self.t['8_n'] = 3
        self.t['8_s'] = 13
        self.t['10_n'] = 5
        self.t['10_s'] = 15
        self.t['13_n'] = 8
        self.t['13_e'] = 14
        self.t['13_s'] = 18
        self.t['14_w'] = 13
        self.t['14_e'] = 15
        self.t['14_s'] = 19
        self.t['15_n'] = 10
        self.t['15_w'] = 14
        self.t['15_s'] = 20
        self.t['16_e'] = 17
        self.t['16_s'] = 21
        self.t['17_w'] = 16
        self.t['17_e'] = 18
        self.t['17_s'] = 22
        self.t['18_w'] = 17
        self.t['18_e'] = 19
        self.t['18_n'] = 13
        self.t['19_w'] = 18
        self.t['19_n'] = 14
        self.t['19_e'] = 20
        self.t['20_w'] = 19
        self.t['20_n'] = 15
        self.t['21_n'] = 16
        self.t['21_e'] = 22
        self.t['22_w'] = 21
        self.t['22_n'] = 17

        # Terminal (goal) states; drawn yellow by render().
        self.terminate_states = [15]

        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None

    def getTerminate_states(self):
        """Return the list of terminal states."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the list of all states (obstacle cells included)."""
        return self.states

    def getAction(self):
        """Return the list of action strings."""
        return self.actions

    def setStatus(self, s):
        """Set the current state to ``s`` (method name kept for existing callers)."""
        self.state = s

    def step(self, action):
        """Apply ``action`` from the current state.

        Args:
            action: one of 'n', 'e', 's', 'w'.  Any value without an entry in
                the transition table behaves like a blocked move: the agent
                stays in place and the reward is -1.

        Returns:
            ``(next_state, reward, is_terminal, info)``, where ``info`` is
            always ``{}``.  If the current state is already terminal, nothing
            moves and ``(state, 0, True, {})`` is returned.
        """
        state = self.state
        if state in self.terminate_states:
            return state, 0, True, {}

        # "%d" matches the "<state>_<action>" key format of self.t / self.rewards.
        key = "%d_%s" % (state, action)

        next_state = self.t.get(key, state)
        self.state = next_state

        is_terminal = next_state in self.terminate_states
        r = self.rewards.get(key, -1)
        return next_state, r, is_terminal, {}

    def reset(self):
        """Place the agent on a random non-obstacle cell and return that state.

        The start state is drawn uniformly from every state that is not an
        obstacle.  (The previous index arithmetic could never draw the last
        state.)  The goal state is a valid start; in that case the next
        ``step()`` immediately reports ``done=True``.
        """
        candidates = [s for s in self.states if s not in self.obstacles]
        self.state = random.choice(candidates)
        return self.state

    def close(self):
        """Close the render window, if one is open."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

    def render(self, mode="human", close=False):
        """Draw the maze and the agent with gym's classic-control viewer.

        Args:
            mode: forwarded to the viewer as
                ``return_rgb_array=(mode == 'rgb_array')``.
            close: legacy flag; when true, close the viewer and return None.

        Returns:
            None if ``close`` is set or no state has been set yet; otherwise
            whatever ``self.viewer.render(...)`` returns.
        """
        if close:
            self.close()
            return

        n = 5         # cells per row and per column
        width = 80    # cell width in pixels
        height = 40   # cell height in pixels
        edge_x = 100  # left margin of the grid
        edge_y = 100  # bottom margin of the grid

        screen_width = 600
        screen_height = 400

        def cell_origin(s):
            # Lower-left pixel of state s.  The viewer's y axis points up,
            # so grid row 0 (states 1..5) is drawn at the top.
            row, col = divmod(s - 1, n)
            return edge_x + col * width, edge_y + (n - 1 - row) * height

        def make_cell(s, rgb):
            # Filled, cell-sized rectangle placed over state s.
            cell = rendering.make_polygon([(0, 0), (0, height), (width, height), (width, 0)], filled=True)
            cell.set_color(*rgb)
            cell.add_attr(rendering.Transform(translation=cell_origin(s)))
            return cell

        if self.viewer is None:
            # Static geometry is built only once, together with the viewer.
            self.viewer = rendering.Viewer(screen_width, screen_height)

            # The moving agent.
            self.circle = rendering.make_circle(20)
            self.circle_trans = rendering.Transform()
            self.circle.add_attr(self.circle_trans)
            self.circle.set_color(0.8, 0.6, 0.4)

            # Obstacles in black, goal cell(s) in yellow, agent drawn on top.
            for s in self.obstacles:
                self.viewer.add_geom(make_cell(s, (0, 0, 0)))
            for s in self.terminate_states:
                self.viewer.add_geom(make_cell(s, (1, 0.9, 0)))
            self.viewer.add_geom(self.circle)

        # Grid lines are drawn on every call.
        for i in range(n + 1):
            self.viewer.draw_line((edge_x, edge_y + i * height), (edge_x + n * width, edge_y + i * height))
            self.viewer.draw_line((edge_x + i * width, edge_y), (edge_x + i * width, edge_y + n * height))

        if self.state is None:
            return None

        # Center the agent in its current cell.
        x, y = cell_origin(self.state)
        self.circle_trans.set_translation(x + 0.5 * width, y + 0.5 * height)
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')


if __name__ == '__main__':
    # Smoke test: draw one random start position, keep the window up for a
    # second, and always release the viewer, even if rendering fails.
    env = Maze()
    try:
        env.reset()
        env.render()
        time.sleep(1)
    finally:
        env.close()


安装

cd gym_maze  # 进入外层 gym_maze 目录, 该目录下包含内层的 gym_maze 包和 setup.py
pip install -e .
安装完成后, 可以用 pip list(或 conda list)确认其中已包含自己安装的包 gym_maze。

使用

下面简单介绍使用方法:

import gym
import gym_maze  # noqa: F401 -- must be imported: gym_maze/__init__.py calls register(), adding 'maze-v0' to gym
import time

env = gym.make('maze-v0')
try:
    env.reset()
    env.render()
    time.sleep(1)
finally:
    # Always close the render window, even if reset/render raised.
    env.close()

# Print every registered environment id; 'maze-v0' should be among them.
# Older gym releases expose registry.all(); newer ones make the registry a
# plain dict of EnvSpec objects, so fall back to .values() there.
registry = gym.envs.registry
specs = registry.all() if hasattr(registry, 'all') else registry.values()
for env_id in (spec.id for spec in specs):
    print(env_id)

参考

自定义并注册Gym环境
Gym官网tutorial

你可能感兴趣的:(python,学习,人工智能)