MAML-RL Pytorch 代码解读 (5) -- maml_rl/envs/init.py和utils.py

MAML-RL Pytorch 代码解读 (5) – maml_rl/envs/init.py和utils.py

文章目录

  • MAML-RL Pytorch 代码解读 (5) -- maml_rl/envs/init.py和utils.py
      • 基本介绍
        • 源码链接
        • 文件路径
      • `__init__.py` 文件
        • `import` 包
        • 代码段
      • `utils.py` 文件
        • `import` 包
        • `mujoco_wrapper` 函数

基本介绍

在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。

因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。

MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。

源码链接

https://github.com/dragen1860/MAML-Pytorch-RL

文件路径

./maml_rl/envs/__init__.py

./maml_rl/envs/utils.py

__init__.py 文件

这个文件的主要工作是登记一下这些环境,以便调用时方便调用。

import

from gym.envs.registration import register

代码段

####  多臂赌博机环境:一共有[5, 10, 50]三款,分别就是前面名字+数字+v0版本。这个entry_point相当于一个路口,当这个被注册时,这个环境的观测、动作、动力学、采样和执行等信息具体在哪里?通过这个entry_point指向对应的类。多臂赌博机指向了"./maml_rl/envs/bandit.py"文件的BernoulliBanditEnv()类。
####  其他环境同理理解。每个环境还附加了一些特定的键值对,这个要注意!

# Bandit
# ----------------------------------------

for k in [5, 10, 50]:
	register(
		'Bandit-K{0}-v0'.format(k),
		entry_point='maml_rl.envs.bandit:BernoulliBanditEnv',
		kwargs={'k': k}
	)

#### 10个状态、5个动作,一个episode最长走10步
# TabularMDP
# ----------------------------------------

register(
	'TabularMDP-v0',
	entry_point='maml_rl.envs.mdp:TabularMDPEnv',
	kwargs={'num_states': 10, 'num_actions': 5},
	max_episode_steps=10
)

#### mujoco系列的环境,先指向了mujoco_wrapper入口函数做初步整理,最后再跳到mujoco里面的环境。
# Mujoco
# ----------------------------------------

register(
	'AntVel-v1',
	entry_point='maml_rl.envs.utils:mujoco_wrapper',
	kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntVelEnv'},
	max_episode_steps=200
)

register(
	'AntDir-v1',
	entry_point='maml_rl.envs.utils:mujoco_wrapper',
	kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntDirEnv'},
	max_episode_steps=200
)

register(
	'AntPos-v0',
	entry_point='maml_rl.envs.utils:mujoco_wrapper',
	kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntPosEnv'},
	max_episode_steps=200
)

register(
	'HalfCheetahVel-v1',
	entry_point='maml_rl.envs.utils:mujoco_wrapper',
	kwargs={'entry_point': 'maml_rl.envs.mujoco.half_cheetah:HalfCheetahVelEnv'},
	max_episode_steps=200
)

register(
	'HalfCheetahDir-v1',
	entry_point='maml_rl.envs.utils:mujoco_wrapper',
	kwargs={'entry_point': 'maml_rl.envs.mujoco.half_cheetah:HalfCheetahDirEnv'},
	max_episode_steps=200
)

# 2D Navigation
# ----------------------------------------

register(
	'2DNavigation-v0',
	entry_point='maml_rl.envs.navigation:Navigation2DEnv',
	max_episode_steps=100
)

utils.py 文件

这个文件主要是对mujoco系列环境做一些处理。

import

from gym.envs.registration import load

from .normalized_env import NormalizedActionWrapper

mujoco_wrapper 函数

def mujoco_wrapper(entry_point, **kwargs):
    
    #### 通过load承接entry_point载入mujoco环境
    # Load the environment from its entry point
    env_cls = load(entry_point)
    
    #### 承接**kwargs键值对,建立环境
    env = env_cls(**kwargs)
    
    #### 将归一化到[-1,1]上的动作信息扩展成自己环境的类型
    # Normalization wrapper
    env = NormalizedActionWrapper(env)

    return env

你可能感兴趣的:(MetaRL_Notes,源码解读,pytorch,深度学习,python)