MAML-RL Pytorch 代码解读 (10) -- maml_rl/envs/subproc_vec_env.py

MAML-RL Pytorch 代码解读 (10) – maml_rl/envs/subproc_vec_env.py

文章目录

  • MAML-RL Pytorch 代码解读 (10) -- maml_rl/envs/subproc_vec_env.py
      • 基本介绍
        • 源码链接
        • 文件路径
      • `import` 包
      • `EnvWorker()` 类
      • `SubprocVecEnv()` 类
      • 总结

基本介绍

在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。

因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。

MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。

源码链接

https://github.com/dragen1860/MAML-Pytorch-RL

文件路径

./maml_rl/envs/subproc_vec_env.py

import

import numpy as np
import multiprocessing as mp
import gym
import sys

import queue

EnvWorker()

class EnvWorker(mp.Process):

	def __init__(self, remote, env_fn, queue_, lock):
        
        #### 类初始化,继承了mp.Process()类,应该是一个做多线程的类。结合名字应该是将线程、函数信息、队列和锁传递给了“环境工人”,那么将发出器传出的指令接受到“环境工人”中。
		"""

		:param remote: send/recv connection, type of Pipe
		:param env_fn: construct environment function
		:param queue_: global queue instance
		:param lock: Every worker has a lock
		"""
		super(EnvWorker, self).__init__()

		self.remote = remote # Pipe()
		self.env = env_fn() # return a function
		self.queue = queue_
		self.lock = lock
		self.task_id = None
		self.done = False

    #### 这个函数是用来进行一步啥也不干的时间步。观测信息是全0向量,都没有;奖励为0;done标记是True。
	def empty_step(self):
		"""
		conduct a dummy step
		:return:
		"""
		observation = np.zeros(self.env.observation_space.shape, dtype=np.float32)
		reward, done = 0.0, True

		return observation, reward, done, {}

    #### 这个函数应该是重置环境的意思。self.lock应该是进程的锁,当锁被打开时,依次读取任务的id号码,当队列里面没有id号码的时候,执行except异常处理,self.done设置为True。当队列里面还有id号码的时候,self.done设置为False。
	def try_reset(self):
		"""

		:return:
		"""
		with self.lock:
			try:
				self.task_id = self.queue.get(True) # block = True
				self.done = (self.task_id is None)
			except queue.Empty:
				self.done = True

        #### 如果self.done设置为True,说明队列没有任务了,观测信息就清空;否则就重置环境。
		# construct empty state or get state from env.reset()
		observation = np.zeros(self.env.observation_space.shape, dtype=np.float32) if self.done else self.env.reset()

		return observation

    #### 对线程做一些处理。
	def run(self):
		"""

		:return:
		"""
		while True:
            
            #### 应该是从管道中接受到数据和指令,类似于串口输入
			command, data = self.remote.recv()

            #### 如果指令是'step',说明就要在环境中采取一次时间步,获得下一时刻的观测信息、奖励和是否完成信号。如果self.done设置为True,说明队列没有任务了,就执行空步骤;反之就是将接受到的data数据,应该是动作信息,传入self.env.step()函数中,获得下一时刻的观测信息、奖励和是否完成信号。
			if command == 'step':
				observation, reward, done, info = (self.empty_step() if self.done else self.env.step(data))
                
                #### 如果一个episode执行完了,self.done设置为False,这说明还需要继续执行,因此就self.try_reset()获得观测信息。
				if done and (not self.done):
					observation = self.try_reset()
                    
                #### 将获得的下一时刻的观测信息、奖励和是否完成信号输出到管道中。
				self.remote.send((observation, reward, done, self.task_id, info))

            #### 如果指令是'reset',说明只是纯粹的重置环境,那么就执行self.try_reset()将状态信息发送给管道。
			elif command == 'reset':
				observation = self.try_reset()
				self.remote.send((observation, self.task_id))
                
            #### 如果指令是'reset_task',也就是任务重置,那么就执行self.env.unwrapped.reset_task(data),并输出重置成功的标志True。
			elif command == 'reset_task':
				self.env.unwrapped.reset_task(data)
				self.remote.send(True)
                
            #### 如果指令是'close',说明进程要结束了,执行self.remote.close()。
			elif command == 'close':
				self.remote.close()
				break
                
            #### 如果指令是'get_spaces',也就是获得空间的观测信息,将当前环境的观测信息发出去。如果是其他指令,就报异常。
			elif command == 'get_spaces':
				self.remote.send((self.env.observation_space, self.env.action_space))
			else:
				raise NotImplementedError()

SubprocVecEnv()

class SubprocVecEnv(gym.Env):

    #### 这个类应该是创建子进程环境。self.lock应该是设置进程锁,保证多进程中只有一个进程是读写数据的。self.remotes和self.work_remotes数据的收发。用EnvWorker()类,为每个子进程构建一个小智能体。
	def __init__(self, env_factorys, queue_):
		"""

		:param env_factorys: list of [lambda x: def p: envs.make(env_name), return p], len: num_workers
		:param queue:
		"""
		self.lock = mp.Lock()
		# remotes: all recv conn, len: 8, here duplex=True
		# works_remotes: all send conn, len: 8, here duplex=True
		self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in env_factorys])

		# queue and lock is shared.
		self.workers = [EnvWorker(remote, env_fn, queue_, self.lock)
		                    for (remote, env_fn) in zip(self.work_remotes, env_factorys)]
		
        #### 在for循环里面有依次使能一个智能体这样。
        # start 8 processes to interact with environments.
		for worker in self.workers:
			worker.daemon = True
			worker.start()
		for remote in self.work_remotes:
			remote.close()

		self.waiting = False # for step_async
		self.closed = False

        #### 看作者的注释说,既然父进程需要跟子进程联系,那么需要用一个方式传递这些数据。在这里使用mp.Pipe()类来收发数据。将收到的数据解耦赋值给self.observation_space和self.action_space。
		# Since the main process need talk to children processes, we need a way to comunicate between these.
		# here we use mp.Pipe() to send/recv data.
		self.remotes[0].send(('get_spaces', None))
		observation_space, action_space = self.remotes[0].recv()
		self.observation_space = observation_space
		self.action_space = action_space

    #### 等待每个子进程环境下的运行结果,输出的结果就是self.step_wait()的结果,也就是一个时间步下面的状态、奖励、是否完成的信息。
	def step(self, actions):
		"""
		step synchronously
		:param actions:
		:return:
		"""
		self.step_async(actions)
		# wait until step state overdue
		return self.step_wait()

    #### 将每个进程和实时动作信息打包,发送给各个子进程,然后打上self.waiting = True标签。
	def step_async(self, actions):
		"""
		step asynchronouly
		:param actions:
		:return:
		"""
		# let each sub-process step
		for remote, action in zip(self.remotes, actions):
			remote.send(('step', action))
		self.waiting = True

    #### 收集每个远程子进程的数据,保存到results中,self.waiting设置成False,将results的内容分解出来,得到下一个时间步的观测、奖励信号、是否完成、任务号和其他信息。最后将这些观测信息又拼接起来。
	def step_wait(self):
		results = [remote.recv() for remote in self.remotes]
		self.waiting = False
		observations, rewards, dones, task_ids, infos = zip(*results)
		return np.stack(observations), np.stack(rewards), np.stack(dones), task_ids, infos

    #### 同步地重置环境。将重置环境的结果保存在results中。解耦合results可以得到每个任务的重置观测和任务序列号task_ids。最后整合所有的初始观测和任务序列号。
	def reset(self):
		"""
		reset synchronously
		:return:
		"""
		for remote in self.remotes:
			remote.send(('reset', None))
		results = [remote.recv() for remote in self.remotes]
		observations, task_ids = zip(*results)
		return np.stack(observations), task_ids
    
    #### 重置整个任务,输出的是重置任务后的所有数据。
	def reset_task(self, tasks):
		for remote, task in zip(self.remotes, tasks):
			remote.send(('reset_task', task))
		return np.stack([remote.recv() for remote in self.remotes])

    #### 关闭一系列子进程。如果已经是关闭的了,不用执行直接返回。如果是self.waiting==True,先接受每个子进程的数据,然后对每个子进程输出‘close’结束的标志,最后关闭。
	def close(self):
		if self.closed:
			return
		if self.waiting: # cope with step_async()
			for remote in self.remotes:
				remote.recv()
		for remote in self.remotes:
			remote.send(('close', None))
		for worker in self.workers:
			worker.join()
		self.closed = True

总结

这个类应该是创建子进程的过程,可能这样高效一些。

具体来说是异步发送指令信号,同步接受并执行,最后父进程依次收取子进程的信息并打包起来。

这里有些库涉及到了 multiprocessing 这个库,所以还需要调其他文档再理解一下。

你可能感兴趣的:(源码解读,MetaRL_Notes,pytorch,python,深度学习)