环境版本:
- gym 0.26.1
- pygame 2.1.2
自定义环境 GridWorldEnv
本教程参考官网的“自定义环境”教程(Make your own custom environment),我把其中一些可能让人疑惑的地方讲解一下。
首先是整体的文件结构(这里省略了 wrappers 目录):
gym-examples/
    main.py            # 这个是测试自定义的环境
    setup.py
    gym_examples/
        __init__.py
        envs/
            __init__.py
            grid_world.py
先讲几个基础知识:
- `__init__.py` 的作用
  最主要的作用是:将所在的目录标记为一个 Python 包。
  在 Python 中,包是一个包含模块(即 .py 文件)的目录,
  而 `__init__.py` 文件表明这个目录可以被视为一个包,从而允许从这个目录导入模块或子包。
  另外,导入包时会执行其 `__init__.py` 中的代码,后面注册环境正是利用了这一点。
- class 里以 `_` 开头的变量被视为私有变量,以 `_` 开头的方法被视为私有方法(这只是约定俗成的规定,并不强制)。
- 实例变量的初始化不一定要放在 `__init__` 函数里,比如这里有些变量就是在 `reset` 函数里初始化的。
grid_world.py
原版的英文注释已经很清楚了,所以这里直接沿用原注释。
import gym
from gym import spaces
import pygame
import numpy as np
class GridWorldEnv(gym.Env):
    """A square grid world in which an agent must walk onto a target cell.

    Observations are dicts holding the agent's and the target's (x, y) cell.
    Actions 0-3 move the agent right/up/left/down by one cell (clipped to the
    grid). The episode terminates with reward 1 once the agent reaches the
    target; every other step yields reward 0.

    Args:
        render_mode: None, "human" (draw into a pygame window) or
            "rgb_array" (``render()`` returns the frame as an array).
        size: Side length of the square grid; must be at least 2 so the
            agent and target can start on different cells.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, size=5):
        super().__init__()
        # reset() samples the target until it differs from the agent's cell;
        # with fewer than two cells that loop could never finish.
        if size < 2:
            raise ValueError(f"size must be >= 2, got {size}")
        self.size = size  # The size of the square grid
        self.window_size = 512  # The size of the PyGame window

        # Observations are dictionaries with the agent's and the target's
        # location. Each location is an element of {0, ..., size - 1}^2,
        # stored as an integer Box of shape (2,).
        self.observation_space = spaces.Dict(
            {
                "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )

        # We have 4 actions, corresponding to "right", "up", "left", "down".
        self.action_space = spaces.Discrete(4)

        # Maps abstract actions from `self.action_space` to the direction we
        # walk in if that action is taken, i.e. 0 is "right", 1 is "up", etc.
        # NOTE: "up" means +y in grid coordinates; pygame's y axis points down,
        # so in the rendered window action 1 moves the agent downward.
        self._action_to_direction = {
            0: np.array([1, 0]),
            1: np.array([0, 1]),
            2: np.array([-1, 0]),
            3: np.array([0, -1]),
        }

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        # With human rendering, `self.window` is the window we draw to and
        # `self.clock` keeps rendering at `render_fps`. Both stay None until
        # human-mode rendering happens for the first time.
        self.window = None
        self.clock = None

    def _get_obs(self):
        """Return the current observation dict (agent and target cells)."""
        return {"agent": self._agent_location, "target": self._target_location}

    def _get_info(self):
        """Return auxiliary info: the Manhattan (L1) distance agent -> target."""
        return {
            "distance": np.linalg.norm(
                self._agent_location - self._target_location, ord=1
            )
        }

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        The agent's cell is drawn uniformly at random with ``self.np_random``;
        the target's cell is re-drawn until it differs from the agent's.
        ``options`` is accepted for API compatibility but is not used.
        """
        # We need the following line to seed self.np_random.
        super().reset(seed=seed)

        # Choose the agent's location uniformly at random.
        self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)

        # Sample the target's location until it does not coincide with the
        # agent's location.
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int
            )

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """Move the agent one cell; return the gym 0.26 five-tuple.

        Returns ``(observation, reward, terminated, truncated, info)``. The
        new position is clipped to the grid, so walking into a wall leaves the
        agent in place. ``terminated`` is True once the agent stands on the
        target (reward 1, otherwise 0). ``truncated`` is always False: this
        env never cuts episodes short itself (a step limit such as
        ``max_episode_steps`` has to be applied from outside).
        """
        # Map the action (element of {0, 1, 2, 3}) to the direction we walk in.
        direction = self._action_to_direction[action]
        # We use `np.clip` to make sure we don't leave the grid.
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )

        # An episode is done iff the agent has reached the target.
        terminated = np.array_equal(self._agent_location, self._target_location)
        reward = 1 if terminated else 0
        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        truncated = False
        return observation, reward, terminated, truncated, info

    def render(self):
        """Return the current frame in "rgb_array" mode; otherwise return None.

        In "human" mode frames are drawn automatically by reset() and step().
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """Draw the grid, target and agent onto an off-screen canvas.

        In "human" mode the canvas is copied to the pygame window (created on
        first use) and the call is throttled to ``render_fps``; nothing is
        returned. Otherwise the canvas is returned as a numpy array of shape
        (window_size, window_size, 3), rows first.
        """
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        # The size of a single grid square in pixels.
        pix_square_size = self.window_size / self.size

        # First we draw the target.
        pygame.draw.rect(
            canvas,
            (255, 0, 0),
            pygame.Rect(
                pix_square_size * self._target_location,
                (pix_square_size, pix_square_size),
            ),
        )
        # Now we draw the agent, centred in its cell.
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            (self._agent_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )

        # Finally, add some gridlines (size + 1 lines in each direction).
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                0,
                (0, pix_square_size * x),
                (self.window_size, pix_square_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, 0),
                (pix_square_size * x, self.window_size),
                width=3,
            )

        if self.render_mode == "human":
            # The following line copies our drawings from `canvas` to the visible window.
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # We need to ensure that human-rendering occurs at the predefined framerate.
            # The following line will automatically add a delay to keep the framerate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            # pygame surfaces are indexed (x, y); transpose to (row, col, channel).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut down the pygame window, if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop the stale handles so that a later human-mode render
            # re-initialises pygame instead of blitting to a closed display.
            self.window = None
            self.clock = None
`envs` 目录下的 `__init__.py`:
from gym_examples.envs.grid_world import GridWorldEnv
与 `envs` 同级别的 `__init__.py`(即 `gym_examples/__init__.py`):
这里必须先通过 `register` 注册环境,之后才能用 `gym.make` 按 id 创建它。
from gym.envs.registration import register

# Make the custom environment known to gym so that
# gym.make("gym_examples/GridWorld-v0") can construct it.
register(
    # "<module path>:<class name>" — follows the package layout and class name.
    entry_point='gym_examples.envs:GridWorldEnv',
    # Free to choose, but it must be unique and must not clash with any
    # already-registered environment id.
    id='gym_examples/GridWorld-v0',
    # Maximum number of steps per episode.
    max_episode_steps=300,
)
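最外层的 setup.py
主要作用:
- 定义包的元数据,包括包名和版本号。
- 管理依赖。
- 如果其他人想要使用你的 gym_examples 包,只需要克隆你的代码库,并在包的根目录下运行 `pip install .`(开发时也可以用 `pip install -e .` 以可编辑模式安装)。这会自动安装 gym_examples 包以及指定版本的 gym 和 pygame。
所以如果只是在本地开发测试,不用 setup.py 也没有问题。运行 `python main.py` 时,main.py 所在的目录会被加入 `sys.path`,因此可以直接 `import gym_examples`。setup.py 主要负责包的分发和安装。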
from setuptools import find_packages, setup

# Package metadata and dependencies used by `pip install .`.
setup(
    name="gym_examples",
    version="0.0.1",
    # List the packages explicitly. Without `packages=`, only setuptools
    # versions with automatic package discovery pick up gym_examples/ and
    # gym_examples/envs/; older ones install the metadata alone, and
    # `import gym_examples` then fails outside the source tree.
    packages=find_packages(include=["gym_examples", "gym_examples.*"]),
    # Exact versions this example was written against.
    install_requires=["gym==0.26.1", "pygame==2.1.2"],
)
测试用的 main.py
import gym
import gym_examples  # noqa: F401  imported for its side effect: registers GridWorld-v0

# Create the custom environment; "human" mode draws into a pygame window.
env = gym.make('gym_examples/GridWorld-v0', render_mode="human")
try:
    observation, info = env.reset()
    terminated, truncated = False, False
    # Take random actions until the episode ends: either the agent reached the
    # target (terminated) or the registered step limit was hit (truncated).
    while not (terminated or truncated):
        action = env.action_space.sample()
        observation, reward, terminated, truncated, info = env.step(action)
finally:
    # Release the pygame window even if the loop above raises.
    env.close()