本文通过构造经典的悬崖漫步(Cliff Walking)环境,来说明 gym 环境的自定义和使用方法
首先简单介绍悬崖漫步环境,本段引用自 《动手学强化学习》第4章
这是一个 4 x 12 的网格世界,每一个网格表示一个状态。智能体的起点是左下角的状态,目标是右下角的状态,智能体在每一个状态都可以采取 4 种动作:上、下、左、右。如果智能体采取动作后触碰到边界墙壁则状态不发生改变,否则就会相应到达下一个状态。环境中有一段悬崖,智能体掉入悬崖或到达目标状态都会结束动作并回到起点,也就是说掉入悬崖或者达到目标状态是终止状态。智能体每走一步的奖励是 −1,掉入悬崖的奖励是 −100
下面给出在我利用 gym 实现的 Cliff Walking 环境中运行 Q-Learning 的效果(8倍快放)
render_mode='rgb_array'
情况下,通过键盘手动控制 agent 运动的情况(手动交互渲染模式必须设为 rgb_array),可见agent 移动到悬崖或目标点会自动复位到起点。注意到这里没有渲染价值颜色和策略,这是因为渲染模式设置为 rgb_array
时,每次调用渲染方法 render()
会将游戏画面会转像素 ndarray 形式返回,常用于借助 CNN 进行状态观测的情况,为避免影响观测不应渲染额外内容render_mode='human'
情况下,随机采样 action 运行的情况,这时状态价值 V V V 和 policy 都是随机设置的,所以图像在闪烁gym.Env
metadata
属性,指定环境支持的渲染模式和渲染帧率observation_space
和动作空间属性 action_space
,必须是 gym.spaces
某个子类的实例reset()
,step()
,render()
,close()
四个类方法作为环境对外交互的窗口,其他类方法应设置为 _
开头的内部方法遵循以下要点
gym.Env
metadata
属性,从中指定环境支持的渲染模式(如 "human", "rgb_array", "ansi"
等)和渲染帧率。所有环境都支持渲染模式 None
,应该将其设置为 __init__
方法的缺省形参__init__
中定义环境的观测空间 observation_space
和动作空间 action_space
,它们必须是 gym.spaces
某个子类的实例。我这里将观测定义为 agent 当前位置和目标位置的二维坐标,因此观测空间可以设为元素为 gym.spaces.MultiDiscrete
的 gym.spaces.Dict
类实例代码如下
class CliffWalkingEnv(gym.Env):
def __init__(self, render_mode=None, map_size=(4,12), pix_square_size=20):
self.pix_square_size = pix_square_size
self.nrow = map_size[0]
self.ncol = map_size[1]
self.start_location = np.array([0, self.nrow-1], dtype=int)
self.target_location = np.array([self.ncol-1, self.nrow-1], dtype=int)
# 观测空间
self.observation_space = spaces.Dict(
{
"agent": spaces.MultiDiscrete([self.ncol, self.nrow]),
"target": spaces.MultiDiscrete([self.ncol, self.nrow]),
}
)
# 动作空间:上下左右+noop
self.action_space = spaces.Discrete(5)
# 每个动作对应 agent 位置的变化
self._action_to_direction = {
0: np.array([0, 0]), # noop
1: np.array([1, 0]), # right
2: np.array([0, 1]), # down
3: np.array([-1, 0]), # left
4: np.array([0, -1]), # up
}
# 渲染模式支持 'human' 或 'rgb_array'
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
# 渲染模式为 render_mode == 'human' 时用于渲染窗口的组件
self.window = None
self.clock = None
env.reset()
和 env.step()
方法中都要返回 observation
,可以设置一个内部方法进行 state 到 observation 的转换。另外,这里将二者返回的附加信息 info
定义为 agent 当前状态距离目标位置的曼哈顿距离def _get_obs(self):
return {"agent": self._agent_location, "target": self._target_location}
def _get_info(self):
return {"distance": np.linalg.norm(self._agent_location - self._target_location, ord=1)} # 附加信息定义为 agent 当前位置到 target 的曼哈顿距离
reset
方法用于启动一个新轨迹的交互过程,可以假定在调用 reset()
之前不会调用 step()
方法,同时,当 step()
方法发出 terminated
或 truncated
信号时应该调用 reset()
gym.Env
内置一个随机数生成器,一般通过形参 seed
,在 reset()
方法中,通过 super().reset(seed=seed)
实现该随机数生成器的初始化。seed
形参应该有缺省值 None
,按使用规范,创建环境实例后应立即带 seed
值调用 reset()
方法,之后 reset()
时不要再带 seed
。另外,如果第一次调用 reset()
时没有指定 seed
,会利用某个熵源(如时间戳)随机设置随机数生成器reset()
方法应该返回初始 observation 和一些辅助 info,可以使用之前实现的 _get_obs
和 _get_info
方法def reset(self, seed=None, options=None):
'''step 方法给出 terminated 或 truncated 信号后,调用 reset 启动新轨迹'''
# 通过 super 初始化并使用基类的 self.np_random 随机数生成器
super().reset(seed=seed)
# agent 置于起点,设置终点位置
self._agent_location = self.start_location.copy()
self._target_location = self.target_location.copy()
# 获取当前状态观测和附加信息
observation = self._get_obs()
info = self._get_info()
# 可以在此刷新渲染,但本例需要渲染最新策略,所以在测试时更新策略后再手动调用 render 方法
#if self.render_mode == "human":
# self._render_frame()
return observation, info
step()
方法通常包含环境的大部分逻辑。它接受一个 action,计算应用该 action 后的环境 state,并返回元组 (observation, reward, terminated, truncated, info)_get_obs
和 _get_info
方法得到,reward 根据环境定义设置、terminated信号设为 “达到目标位置”,truncated信号设为 “落下悬崖”def step(self, action):
'''环境一步转移'''
# agent 转移到执行 action 后的新位置
self._agent_location = self._state_transition(self._agent_location, action)
# 判断标识 terminated & truncated 并给出 reward
terminated = np.array_equal(self._agent_location, self._target_location)
truncated = self._agent_location[1].item() == self.nrow - 1 and self._agent_location[0].item() not in [0, self.ncol-1]
reward = -1
if terminated: reward = 0
if truncated: reward = -100
# 获取当前状态观测和附加信息
observation = self._get_obs()
info = self._get_info()
# 可以在这里刷新渲染,但我这里需要渲染最新策略,所以在测试时再手动调用 render 方法
#if self.render_mode == "human":
# self._render_frame()
return observation, reward, terminated, truncated, info
render_mode
判断是否渲染价值颜色和策略,另外,由于提供 V value 和 policy 的 RL 算法对环境状态观测进行了 HashPosition
包装(见下文 2.3 节),将原始观测中 key agent
对应的二维位置坐标拉平成一维,所以渲染时要手动处理一下改回来def render(self, state_values=None, policy=None):
if self.render_mode == "rgb_array":
return self._render_frame() # 'rgb_array' 渲染模式下画面会转换为像素 ndarray 形式返回,通常用于借助 CNN 从游戏画面提取观测向量的情况,为避免影响观测不要渲染价值颜色和策略
elif self.render_mode == "human":
self._render_frame(state_values, policy) # 'human' 渲染模式下会弹出窗口,如果不直接通过游戏画面提取状态观测,可以渲染价值颜色和策略,以便人员观察收敛情况
else:
raise False # 不支持其他渲染模式,报错
def _state_transition(self, state, action):
'''返回 agent 在 state 出执行 action 后转移到的新位置'''
direction = self._action_to_direction[action]
state += direction
state[0] = np.clip(state[0], 0, self.ncol-1).item()
state[1] = np.clip(state[1], 0, self.nrow-1).item()
return state
def _render_frame(self, state_values=None, policy=None):
pix_square_size = self.pix_square_size
canvas = pygame.Surface((self.ncol*pix_square_size, self.nrow*pix_square_size))
canvas.fill((255, 255, 255))
if self.window is None and self.render_mode == "human":
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode((self.ncol*pix_square_size, self.nrow*pix_square_size))
if self.clock is None and self.render_mode == "human":
self.clock = pygame.time.Clock()
# 背景白色
pygame.draw.rect(
canvas,
(255, 255, 255),
pygame.Rect(
(0, 0),
(pix_square_size*self.ncol, pix_square_size*self.nrow),
),
)
# 绘制远离悬崖的方格
if self.render_mode == "human" and isinstance(state_values, np.ndarray): # human 模式下渲染状态颜色
for col in range(self.ncol):
for row in range(self.nrow-1):
state_value = state_values[row][col].item()
max_value = 1 if np.abs(state_values).max() == 0 else np.abs(state_values).max()
pygame.draw.rect(
canvas,
(abs(state_value)/max_value*255, 20, 20), # 通过颜色反映 state value
pygame.Rect(
(col*pix_square_size, row*pix_square_size),
(pix_square_size-1, pix_square_size-1), # 每个状态格边长减小1,这样自动出现缝线
),
)
else: # rgb_array 模式下不渲染状态颜色
for col in range(self.ncol):
for row in range(self.nrow-1):
pygame.draw.rect(
canvas,
(150, 150, 150),
pygame.Rect(
(col*pix_square_size, row*pix_square_size),
(pix_square_size-1, pix_square_size-1),
),
)
# 绘制悬崖边最后一行方格
for col in range(self.ncol):
if col == 0:
color = (100, 100, 100) # 起点
elif col == self.ncol-1:
color = (100, 150, 100) # 终点
else:
color = (0, 0, 0) # 悬崖
pygame.draw.rect(
canvas,
color,
pygame.Rect(
(col*pix_square_size, (self.nrow-1)*pix_square_size),
(pix_square_size-1, pix_square_size-1), # 每个状态格边长减小1,这样自动出现缝线
),
)
# 绘制 agent
pygame.draw.circle(
canvas,
(0, 0, 255),
(self._agent_location + 0.5) * pix_square_size,
pix_square_size / 3,
)
# human 模式下渲染基于 Q value 的贪心策略
if self.render_mode == "human" and isinstance(policy, np.ndarray):
# 前几行正常行走区域
for col in range(self.ncol):
for row in range(self.nrow-1):
hash_position = col*self.nrow + row
actions = policy[hash_position]
for a in actions:
s_ = self._state_transition(np.array([col,row]), a)
if (s_ != np.array([col,row])).sum() != 0:
start = np.array([col*pix_square_size+0.5*pix_square_size,row*pix_square_size+0.5*pix_square_size])
end = s_*pix_square_size+0.5*pix_square_size
dot_num = 15
for i in range(dot_num):
pygame.draw.rect(
canvas,
(10, 255-i*175/dot_num, 10),
pygame.Rect(
start + (end-start) * i/dot_num,
(2,2)
),
)
# 最后一行只绘制起点策略
col, row = 0, self.nrow-1
hash_position = col*self.nrow + row
actions = policy[hash_position]
for a in actions:
s_ = self._state_transition(np.array([col,row]), a)
if (s_ != np.array([col,row])).sum() != 0:
start = np.array([col*pix_square_size+0.5*pix_square_size,row*pix_square_size+0.5*pix_square_size])
end = s_*pix_square_size+0.5*pix_square_size
dot_num = 15
for i in range(dot_num):
pygame.draw.rect(
canvas,
(10, 255-i*175/dot_num, 10),
pygame.Rect(
start + (end-start) * i/dot_num,
(2,2)
),
)
# 'human' 渲染模式下会弹出窗口
if self.render_mode == "human":
# The following line copies our drawings from `canvas` to the visible window
self.window.blit(canvas, canvas.get_rect())
pygame.event.pump()
pygame.display.update()
# We need to ensure that human-rendering occurs at the predefined framerate.
# The following line will automatically add a delay to keep the framerate stable.
self.clock.tick(self.metadata["render_fps"])
# 'rgb_array' 渲染模式下画面会转换为像素 ndarray 形式返回,适用于用 CNN 进行状态观测的情况,为避免影响观测不要渲染价值颜色和策略
else:
return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
close()
方法应该关闭环境所有使用的开放资源,本例中 render_mode
可能是 "human"
,所以关闭可能打开了的窗口def close(self):
if self.window is not None:
pygame.display.quit()
pygame.quit()
env.observation_space['agent']
元素拉平成一维的包装,供后续 RL 算法使用gym.ObservationWrapper
类,然后重写 observation()
方法import gym
# 观测包装,把环境的原生二维观测转为一维的
class HashPosition(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.env = env
map_size = env.observation_space['agent'].nvec
self.observation_space = gym.spaces.Discrete(map_size[0]*map_size[1]) # 新的观测空间
def observation(self, obs):
return obs["agent"][0] * self.env.nrow + obs["agent"][1]
__init__.py
文件,这样 python 就会把这个目录识别为一个包,从而可以在根目录通过 import MyGymExamples
,import MyGymExamples.envs
这样的代码引入这些目录转成的包import MyGymExamples.envs
时,会按从深到浅的顺序依次执行各个目录下的 __init__.py
进行初始化。为了能从任意一层包(目录)处直接 import 其下所有子包(子目录)内定义的类,应该在各层 __init__.py
文件中 import 其所有子目录中的类。本例中三个 __init__.py
的内容如下# envs 目录下的 __init__.py
from MyGymExamples.envs.GridWorld import CliffWalkingEnv
# wrappers 目录下的 __init__.py
from MyGymExamples.wrappers.HashPosition import HashPosition
# MyGymExamples 目录下的 __init__.py
from MyGymExamples.envs.GridWorld import CliffWalkingEnv
from MyGymExamples.wrappers.HashPosition import HashPosition
from MyGymExamples import CliffWalkingEnv
from gym.utils.play import play
from gym.utils.env_checker import check_env
import pygame
map_size = (4,12)
env = CliffWalkingEnv(render_mode='rgb_array',
map_size=map_size,
pix_square_size=30) # 手动交互时渲染模式必须设置为 rgb_array
print(check_env(env.unwrapped)) # 检查 base 环境是否符合 gym 规范
env.action_space.seed(42) # 设置所有随机种子
observation, info = env.reset(seed=42)
# env.step() 后,env.render() 前的回调函数,可用来处理刚刚 timestep 中的运行信息
def palyCallback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
# 非 noop 动作,打印 reward 和附加 info
if action != 0:
print(rew, info)
# key-action 映射关系
mapping = {(pygame.K_UP,): 4,
(pygame.K_DOWN,): 2,
(pygame.K_LEFT,): 3,
(pygame.K_RIGHT,): 1}
# 开始手动交互
play(env, keys_to_action=mapping, callback=palyCallback, fps=15, noop=0)
env.close()
from MyGymExamples import CliffWalkingEnv
from gym.utils.env_checker import check_env
import numpy as np
import random
map_size = (4,12)
env = CliffWalkingEnv(render_mode='human',
map_size=map_size,
pix_square_size=30) # render_mode 设置为 'human' 以渲染价值颜色和贪心策略
print(check_env(env.unwrapped)) # 检查 base 环境是否符合 gym 规范
env.action_space.seed(42) # 设置所有随机种子
observation, info = env.reset(seed=42)
for _ in range(10000):
# 随机采样 action 执行一个 timestep
observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
# 随机产生状态价值和策略进行渲染
env.render(state_values=np.random.randint(0, 10, map_size),
policy=np.array([np.array(random.sample(list(range(5)), random.randint(1, 5))) for _ in range(map_size[0]*map_size[1])], dtype=object))
# 任务完成或失败,重置环境
if terminated or truncated:
print(reward, info)
observation, info = env.reset()
env.close()
CliffWalkingEnv
环境,因此无法像原生环境那样通过 gym.make()
方法定义环境实例。 此问题在第 3 节解决__init__.py
文件中,这样在 import MyGymExamples
时就会自动注册from gym.envs.registration import register
register(
id='MyGymExample/CliffWalkingEnv-v0',
entry_point='MyGymExamples.envs:CliffWalkingEnv',
max_episode_steps=300,
)
id
由三部分组成,可选的 namespace (这里是 MyGymExample)、必选的 name(这里是CliffWalkingEnv)和一个可选的 version(这里是v0)。id 用于在创建环境时指定环境,这里是 gym.make('my_gym_examples/CliffWalkingEnv-v0', render_mode='human', ...)
entry_point
指定环境类在源码中的路径max_episode_steps
指定创建环境时增加一个 TimeLimit
wrapper,轨迹超长就会触发 truncated
标志,可以通过 info["TimeLimit.truncated"]
区分 truncated
和 terminated
。类似地,还可以设置其他自带 wrapper,包括module
包里,然后通过形如 env = gym.make('module:Env-v0')
的代码创建环境。从这个角度看,上面的代码其实相当于把注册代码放到了 MyGymExample
包里,所以创建环境时也可以不写 import MyGymExample
,而是在创建环境时写 env = gym.make('MyGymExamples:MyGymExample/CliffWalkingEnv-v0',...)
。有时某些第三方代码库(比如RL算法库)只允许传入环境 id,这种方法就非常有用了,可以在不改动代码库的情况下注册环境__init__.py
文件中后,有以下三种创建环境的等价方式
from MyGymExamples import HashPosition, CliffWalkingEnv
env = CliffWalkingEnv( render_mode='rgb_array',
map_size=map_size,
pix_square_size=30)
import MyGymExamples
自动执行 MyGymExamples 目录中的 __init__.py
文件,从而执行到注册代码import MyGymExamples
env = gym.make('MyGymExamples/CliffWalkingEnv-v0',
render_mode='rgb_array',
map_size=map_size,
pix_square_size=30)
env = gym.make('MyGymExamples:MyGymExamples/CliffWalkingEnv-v0',
render_mode='rgb_array',
map_size=map_size,
pix_square_size=30)
from setuptools import setup, find_packages
setup(
name="MyGymExamples",
version="0.0.1",
install_requires=["gym==0.26.2", "pygame==2.1.0"],
packages=find_packages()
)
python setup.py sdist
,这会在根目录生成dist文件夹(包含压缩源码的分发包)和.egg-info文件夹(中间临时配置信息)pip install MyGymExamples-0.0.1.tar.gz
直接 pip 安装源码包,这会使 conda 虚拟环境的 site-packages 文件夹中出现自定义 MyGymExamples 包的全部源码和 MyGymExamples-0.0.1.dist-info 配置信息。在命令行输入 conda list
也能看到自定义包的安装信息...
mygymexamples 0.0.1 pypi_0 pypi
...
from MyGymExamples import HashPosition, CliffWalkingEnv
或者 import MyGymExamples
就好