下面依次介绍各个文件需要编写的内容。首先是 setup.py：
from setuptools import find_packages, setup

# Packaging metadata for the gym_maze environment package.
setup(
    name='gym_maze',
    version='0.0.1',
    # List gym_maze and its sub-packages (e.g. gym_maze.envs) explicitly;
    # setuptools releases without automatic package discovery would
    # otherwise install no Python modules on a regular (non-editable) install.
    packages=find_packages(),
    install_requires=['gym'],
)
setup.py 主要定义包名、版本号以及所依赖的其他 package。接下来在 gym_maze/__init__.py 中注册环境：
from gym.envs.registration import register

# Add the maze to Gym's registry so gym.make('maze-v0') can construct it.
# entry_point is "<importable module path>:<class name>".
register(id='maze-v0', entry_point='gym_maze.envs:Maze')
id 即调用 gym.make() 时使用的环境名称；entry_point 指明环境类所在的位置，格式为“模块路径:类名”。随后在 gym_maze/envs/__init__.py 中导入 Maze 类，并在 maze_env.py 中实现环境本身：
from gym_maze.envs.maze_env import Maze
import gym
import random
import time
from gym.envs.classic_control import rendering
class Maze(gym.Env):
    """A 5x5 grid maze with obstacles and a single goal cell.

    States are numbered 1..25 row by row: 1..5 is the row drawn at the top
    of the window and 21..25 the row drawn at the bottom. Actions are the
    compass directions 'n', 'e', 's', 'w'. A move that is not listed in the
    transition table (walking into the border or an obstacle) leaves the
    agent where it is. Every step is rewarded -1, except the three moves
    that enter the goal cell (state 15), which earn +10.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    # Drawing geometry used by render(), in pixels.
    _GRID_SIZE = 5
    _CELL_WIDTH = 80
    _CELL_HEIGHT = 40
    _EDGE_X = 100
    _EDGE_Y = 100
    _SCREEN_WIDTH = 600
    _SCREEN_HEIGHT = 400

    def __init__(self):
        self.states = list(range(1, 26))
        self.actions = ['n', 'e', 's', 'w']
        # Cells the agent can never occupy; reset() avoids them and
        # render() draws them black.
        self.obstacle_states = [4, 9, 11, 12, 23, 24, 25]
        # Rewards keyed by "<state>_<action>"; any other move earns -1.
        self.rewards = {
            '10_s': 10.0,
            '14_e': 10.0,
            '20_n': 10.0,
        }
        # Deterministic transitions keyed by "<state>_<action>".
        self.t = {
            '1_s': 6, '1_e': 2,
            '2_w': 1, '2_s': 7, '2_e': 3,
            '3_w': 2, '3_s': 8,
            '5_s': 10,
            '6_n': 1, '6_e': 7,
            '7_w': 6, '7_n': 2, '7_e': 8,
            '8_w': 7, '8_n': 3, '8_s': 13,
            '10_n': 5, '10_s': 15,
            '13_n': 8, '13_e': 14, '13_s': 18,
            '14_w': 13, '14_e': 15, '14_s': 19,
            '15_n': 10, '15_w': 14, '15_s': 20,
            '16_e': 17, '16_s': 21,
            '17_w': 16, '17_e': 18, '17_s': 22,
            '18_w': 17, '18_e': 19, '18_n': 13,
            '19_w': 18, '19_n': 14, '19_e': 20,
            '20_w': 19, '20_n': 15,
            '21_n': 16, '21_e': 22,
            '22_w': 21, '22_n': 17,
        }
        # Terminal (goal) states.
        self.terminate_states = [15]
        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None

    def getTerminate_states(self):
        """Return the list of terminal states."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the list of all state ids (1..25)."""
        return self.states

    def getAction(self):
        """Return the list of action names."""
        return self.actions

    def setStatus(self, s):
        """Set the current state directly (e.g. for planning/evaluation)."""
        self.state = s

    def step(self, action):
        """Apply ``action`` to the current state.

        Returns ``(next_state, reward, is_terminal, info)``; ``info`` is
        always an empty dict. Once a terminal state has been reached, every
        further call returns ``(state, 0, True, {})``. ``reset()`` (or
        ``setStatus()``) must have been called first.
        """
        state = self.state
        if state in self.terminate_states:
            return state, 0, True, {}
        key = "%d_%s" % (state, action)
        # Moves missing from the table hit the border or an obstacle: stay put.
        next_state = self.t.get(key, state)
        self.state = next_state
        is_terminal = next_state in self.terminate_states
        r = self.rewards.get(key, -1)
        return next_state, r, is_terminal, {}

    def reset(self):
        """Place the agent on a uniformly random non-obstacle cell and return it.

        The terminal state is not excluded, so an episode may start at the goal.
        """
        free_states = [s for s in self.states if s not in self.obstacle_states]
        self.state = random.choice(free_states)
        return self.state

    def close(self):
        """Close the rendering window, if one is open."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

    def _cell_origin(self, state):
        """Return the pixel (x, y) of the bottom-left corner of ``state``'s cell."""
        row, col = divmod(state - 1, self._GRID_SIZE)
        x = self._EDGE_X + col * self._CELL_WIDTH
        # Row 0 (states 1..5) is drawn with the largest y, i.e. at the top.
        y = self._EDGE_Y + (self._GRID_SIZE - 1 - row) * self._CELL_HEIGHT
        return x, y

    def _make_cell(self, state, color):
        """Build a filled rectangle covering ``state``'s cell in RGB ``color``."""
        w, h = self._CELL_WIDTH, self._CELL_HEIGHT
        polygon = rendering.make_polygon([(0, 0), (0, h), (w, h), (w, 0)], filled=True)
        polygon.set_color(*color)
        polygon.add_attr(rendering.Transform(translation=self._cell_origin(state)))
        return polygon

    def render(self, mode="human", close=False):
        """Draw the maze and the agent.

        Returns the viewer's render result (an RGB array when ``mode`` is
        'rgb_array'), or ``None`` if no state has been set yet. Passing
        ``close=True`` closes the viewer instead of drawing.
        """
        if close:
            self.close()
            return

        if self.viewer is None:
            self.viewer = rendering.Viewer(self._SCREEN_WIDTH, self._SCREEN_HEIGHT)
            # Static geometry is built once and kept by the viewer:
            # obstacles in black, then the goal cell(s) in yellow.
            for state in self.obstacle_states:
                self.viewer.add_geom(self._make_cell(state, (0, 0, 0)))
            for state in self.terminate_states:
                self.viewer.add_geom(self._make_cell(state, (1, 0.9, 0)))
            # The agent marker, moved to the current cell on every frame.
            self.circle = rendering.make_circle(20)
            self.circle_trans = rendering.Transform()
            self.circle.add_attr(self.circle_trans)
            self.circle.set_color(0.8, 0.6, 0.4)
            self.viewer.add_geom(self.circle)

        # Check before queuing grid lines so nothing is left pending
        # when no frame is rendered.
        if self.state is None:
            return None

        # draw_line() queues per-frame geometry, so the grid is redrawn
        # on every call.
        n = self._GRID_SIZE
        left, bottom = self._EDGE_X, self._EDGE_Y
        right = left + n * self._CELL_WIDTH
        top = bottom + n * self._CELL_HEIGHT
        for i in range(n + 1):
            y = bottom + i * self._CELL_HEIGHT
            x = left + i * self._CELL_WIDTH
            self.viewer.draw_line((left, y), (right, y))
            self.viewer.draw_line((x, bottom), (x, top))

        origin_x, origin_y = self._cell_origin(self.state)
        self.circle_trans.set_translation(origin_x + 0.5 * self._CELL_WIDTH,
                                          origin_y + 0.5 * self._CELL_HEIGHT)
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')
if __name__ == '__main__':
    # Quick visual smoke test: show one random start position for a second.
    env = Maze()
    try:
        env.reset()
        env.render()
        time.sleep(1)
    finally:
        # Release the rendering window explicitly rather than leaving it
        # open until interpreter shutdown.
        env.close()
cd gym_maze  # the outer gym_maze directory, containing setup.py and the inner gym_maze package
# Editable install into the interpreter that `python` resolves to (the active env).
python -m pip install -e .
安装完成后，可以用 conda list（或 pip list）查看其中是否已包含刚安装的 gym_maze 包。
下面简单介绍该环境的使用方法：
import time

import gym
# Importing gym_maze runs its register() call; without this import
# gym.make('maze-v0') cannot find the environment.
import gym_maze

env = gym.make('maze-v0')
try:
    env.reset()
    env.render()
    time.sleep(1)
finally:
    # Always release the rendering window, even if reset/render fails.
    env.close()
# Print the id of every registered environment ('maze-v0' should be listed).
registry = gym.envs.registry
# Older gym releases expose registry.all(); newer ones (0.26+) make the
# registry a plain dict mapping env id -> spec, with no .all() method.
specs = registry.values() if isinstance(registry, dict) else registry.all()
envs = [spec.id for spec in specs]
# Use a distinct loop name so the Env object bound to `env` is not clobbered.
for env_id in envs:
    print(env_id)
自定义并注册Gym环境
Gym官网tutorial