在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。
因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。
MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。
https://github.com/dragen1860/MAML-Pytorch-RL
./maml_rl/envs/subproc_vec_env.py
import
包import numpy as np
import multiprocessing as mp
import gym
import sys
import queue
EnvWorker()
类class EnvWorker(mp.Process):
def __init__(self, remote, env_fn, queue_, lock):
#### 类初始化,继承了mp.Process()类,应该是一个做多线程的类。结合名字应该是将线程、函数信息、队列和锁传递给了“环境工人”,那么将发出器传出的指令接受到“环境工人”中。
"""
:param remote: send/recv connection, type of Pipe
:param env_fn: construct environment function
:param queue_: global queue instance
:param lock: Every worker has a lock
"""
super(EnvWorker, self).__init__()
self.remote = remote # Pipe()
self.env = env_fn() # return a function
self.queue = queue_
self.lock = lock
self.task_id = None
self.done = False
#### 这个函数是用来进行一步啥也不干的时间步。观测信息是全0向量,都没有;奖励为0;done标记是True。
def empty_step(self):
"""
conduct a dummy step
:return:
"""
observation = np.zeros(self.env.observation_space.shape, dtype=np.float32)
reward, done = 0.0, True
return observation, reward, done, {}
#### 这个函数应该是重置环境的意思。self.lock应该是进程的锁,当锁被打开时,依次读取任务的id号码,当队列里面没有id号码的时候,执行except异常处理,self.done设置为True。当队列里面还有id号码的时候,self.done设置为False。
def try_reset(self):
"""
:return:
"""
with self.lock:
try:
self.task_id = self.queue.get(True) # block = True
self.done = (self.task_id is None)
except queue.Empty:
self.done = True
#### 如果self.done设置为True,说明队列没有任务了,观测信息就清空;否则就重置环境。
# construct empty state or get state from env.reset()
observation = np.zeros(self.env.observation_space.shape, dtype=np.float32) if self.done else self.env.reset()
return observation
#### 对线程做一些处理。
def run(self):
"""
:return:
"""
while True:
#### 应该是从管道中接受到数据和指令,类似于串口输入
command, data = self.remote.recv()
#### 如果指令是'step',说明就要在环境中采取一次时间步,获得下一时刻的观测信息、奖励和是否完成信号。如果self.done设置为True,说明队列没有任务了,就执行空步骤;反之就是将接受到的data数据,应该是动作信息,传入self.env.step()函数中,获得下一时刻的观测信息、奖励和是否完成信号。
if command == 'step':
observation, reward, done, info = (self.empty_step() if self.done else self.env.step(data))
#### 如果一个episode执行完了,self.done设置为False,这说明还需要继续执行,因此就self.try_reset()获得观测信息。
if done and (not self.done):
observation = self.try_reset()
#### 将获得的下一时刻的观测信息、奖励和是否完成信号输出到管道中。
self.remote.send((observation, reward, done, self.task_id, info))
#### 如果指令是'reset',说明只是纯粹的重置环境,那么就执行self.try_reset()将状态信息发送给管道。
elif command == 'reset':
observation = self.try_reset()
self.remote.send((observation, self.task_id))
#### 如果指令是'reset_task',也就是任务重置,那么就执行self.env.unwrapped.reset_task(data),并输出重置成功的标志True。
elif command == 'reset_task':
self.env.unwrapped.reset_task(data)
self.remote.send(True)
#### 如果指令是'close',说明进程要结束了,执行self.remote.close()。
elif command == 'close':
self.remote.close()
break
#### 如果指令是'get_spaces',也就是获得空间的观测信息,将当前环境的观测信息发出去。如果是其他指令,就报异常。
elif command == 'get_spaces':
self.remote.send((self.env.observation_space, self.env.action_space))
else:
raise NotImplementedError()
SubprocVecEnv()
类class SubprocVecEnv(gym.Env):
#### 这个类应该是创建子进程环境。self.lock应该是设置进程锁,保证多进程中只有一个进程是读写数据的。self.remotes和self.work_remotes数据的收发。用EnvWorker()类,为每个子进程构建一个小智能体。
def __init__(self, env_factorys, queue_):
"""
:param env_factorys: list of [lambda x: def p: envs.make(env_name), return p], len: num_workers
:param queue:
"""
self.lock = mp.Lock()
# remotes: all recv conn, len: 8, here duplex=True
# works_remotes: all send conn, len: 8, here duplex=True
self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in env_factorys])
# queue and lock is shared.
self.workers = [EnvWorker(remote, env_fn, queue_, self.lock)
for (remote, env_fn) in zip(self.work_remotes, env_factorys)]
#### 在for循环里面有依次使能一个智能体这样。
# start 8 processes to interact with environments.
for worker in self.workers:
worker.daemon = True
worker.start()
for remote in self.work_remotes:
remote.close()
self.waiting = False # for step_async
self.closed = False
#### 看作者的注释说,既然父进程需要跟子进程联系,那么需要用一个方式传递这些数据。在这里使用mp.Pipe()类来收发数据。将收到的数据解耦赋值给self.observation_space和self.action_space。
# Since the main process need talk to children processes, we need a way to comunicate between these.
# here we use mp.Pipe() to send/recv data.
self.remotes[0].send(('get_spaces', None))
observation_space, action_space = self.remotes[0].recv()
self.observation_space = observation_space
self.action_space = action_space
#### 等待每个子进程环境下的运行结果,输出的结果就是self.step_wait()的结果,也就是一个时间步下面的状态、奖励、是否完成的信息。
def step(self, actions):
"""
step synchronously
:param actions:
:return:
"""
self.step_async(actions)
# wait until step state overdue
return self.step_wait()
#### 将每个进程和实时动作信息打包,发送给各个子进程,然后打上self.waiting = True标签。
def step_async(self, actions):
"""
step asynchronouly
:param actions:
:return:
"""
# let each sub-process step
for remote, action in zip(self.remotes, actions):
remote.send(('step', action))
self.waiting = True
#### 收集每个远程子进程的数据,保存到results中,self.waiting设置成False,将results的内容分解出来,得到下一个时间步的观测、奖励信号、是否完成、任务号和其他信息。最后将这些观测信息又拼接起来。
def step_wait(self):
results = [remote.recv() for remote in self.remotes]
self.waiting = False
observations, rewards, dones, task_ids, infos = zip(*results)
return np.stack(observations), np.stack(rewards), np.stack(dones), task_ids, infos
#### 同步地重置环境。将重置环境的结果保存在results中。解耦合results可以得到每个任务的重置观测和任务序列号task_ids。最后整合所有的初始观测和任务序列号。
def reset(self):
"""
reset synchronously
:return:
"""
for remote in self.remotes:
remote.send(('reset', None))
results = [remote.recv() for remote in self.remotes]
observations, task_ids = zip(*results)
return np.stack(observations), task_ids
#### 重置整个任务,输出的是重置任务后的所有数据。
def reset_task(self, tasks):
for remote, task in zip(self.remotes, tasks):
remote.send(('reset_task', task))
return np.stack([remote.recv() for remote in self.remotes])
#### 关闭一系列子进程。如果已经是关闭的了,不用执行直接返回。如果是self.waiting==True,先接受每个子进程的数据,然后对每个子进程输出‘close’结束的标志,最后关闭。
def close(self):
if self.closed:
return
if self.waiting: # cope with step_async()
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(('close', None))
for worker in self.workers:
worker.join()
self.closed = True
这个类应该是创建子进程的过程,可能这样高效一些。
具体来说是异步发送指令信号,同步接受并执行,最后父进程依次收取子进程的信息并打包起来。
这里有些库涉及到了 multiprocessing
这个库,所以还需要调其他文档再理解一下。