在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。
因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。
MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。
https://github.com/dragen1860/MAML-Pytorch-RL
./maml_rl/envs/__init__.py
./maml_rl/envs/utils.py
__init__.py
文件这个文件的主要工作是登记一下这些环境,以便调用时方便调用。
import
包from gym.envs.registration import register
#### 多臂赌博机环境:一共有[5, 10, 50]三款,分别就是前面名字+数字+v0版本。这个entry_point相当于一个路口,当这个被注册时,这个环境的观测、动作、动力学、采样和执行等信息具体在哪里?通过这个entry_point指向对应的类。多臂赌博机指向了"./maml_rl/envs/bandit.py"文件的BernoulliBanditEnv()类。
#### 其他环境同理理解。每个环境还附加了一些特定的键值对,这个要注意!
# Bandit
# ----------------------------------------
for k in [5, 10, 50]:
register(
'Bandit-K{0}-v0'.format(k),
entry_point='maml_rl.envs.bandit:BernoulliBanditEnv',
kwargs={'k': k}
)
#### 10个状态、5个动作,一个episode最长走10步
# TabularMDP
# ----------------------------------------
register(
'TabularMDP-v0',
entry_point='maml_rl.envs.mdp:TabularMDPEnv',
kwargs={'num_states': 10, 'num_actions': 5},
max_episode_steps=10
)
#### mujoco系列的环境,先指向了mujoco_wrapper入口函数做初步整理,最后再跳到mujoco里面的环境。
# Mujoco
# ----------------------------------------
register(
'AntVel-v1',
entry_point='maml_rl.envs.utils:mujoco_wrapper',
kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntVelEnv'},
max_episode_steps=200
)
register(
'AntDir-v1',
entry_point='maml_rl.envs.utils:mujoco_wrapper',
kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntDirEnv'},
max_episode_steps=200
)
register(
'AntPos-v0',
entry_point='maml_rl.envs.utils:mujoco_wrapper',
kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntPosEnv'},
max_episode_steps=200
)
register(
'HalfCheetahVel-v1',
entry_point='maml_rl.envs.utils:mujoco_wrapper',
kwargs={'entry_point': 'maml_rl.envs.mujoco.half_cheetah:HalfCheetahVelEnv'},
max_episode_steps=200
)
register(
'HalfCheetahDir-v1',
entry_point='maml_rl.envs.utils:mujoco_wrapper',
kwargs={'entry_point': 'maml_rl.envs.mujoco.half_cheetah:HalfCheetahDirEnv'},
max_episode_steps=200
)
# 2D Navigation
# ----------------------------------------
register(
'2DNavigation-v0',
entry_point='maml_rl.envs.navigation:Navigation2DEnv',
max_episode_steps=100
)
utils.py
文件这个文件主要是对mujoco系列环境做一些处理。
import
包from gym.envs.registration import load
from .normalized_env import NormalizedActionWrapper
mujoco_wrapper
函数def mujoco_wrapper(entry_point, **kwargs):
#### 通过load承接entry_point载入mujoco环境
# Load the environment from its entry point
env_cls = load(entry_point)
#### 承接**kwargs键值对,建立环境
env = env_cls(**kwargs)
#### 将归一化到[-1,1]上的动作信息扩展成自己环境的类型
# Normalization wrapper
env = NormalizedActionWrapper(env)
return env