Stable-baselines3的基本原理与使用-1

Stable-baselines3的基本原理与使用-1

  • 一、极简例子
  • 二、从环境的基类入手
    • 2.1 空间类型的基类gym.spaces.Space
    • 2.2 单个环境的基类gym.Env
    • 2.3 多个环境的基类VecEnv
    • 2.4 串行环境DummyVecEnv(VecEnv)
    • 2.5 并行环境SubprocVecEnv(VecEnv)
  • 三、预处理与特征提取
  • 四、强化模型的基类
    • 4.1 BaseModel(nn.Module)
    • 4.2 BasePolicy(BaseModel)
    • 4.3 ActorCriticPolicy(BasePolicy)
    • 4.4 ContinuousCritic(BaseModel)
    • 4.5 总结:RL模型基类的抽象逻辑
  • 五、强化存储的基类
    • 5.1 BaseBuffer
    • 5.2 RolloutBuffer(BaseBuffer)
    • 5.3 ReplayBuffer(BaseBuffer)
  • 六、强化算法的基类BaseAlgorithm

目的:学习开源库对强化学习的设计与封装逻辑,本质是阅读源码的笔记,实战见下一篇文章
基础如下:
Stable-baselines github
Stable-baselines3 文档
自定义的强化学习环境镜像

一、极简例子

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

#定义环境、DQN agent
env = gym.make('CartPole-v1')
#如果需要并行环境
#env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
model = DQN('MlpPolicy', env, verbose=1)

#agent的学习
model.learn(total_timesteps=int(2e5))
#agent内参数的保存,在当前目录下多了一个dqn_cartpole.zip文件
model.save("dqn_cartpole")
del model
#加载保存在dqn_cartpole.zip中的agent
model = DQN.load("dqn_cartpole", env=env)
#评估
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
#可视化
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    env.render()

其中DummyVecEnv是“串行”的vectorized环境,而SubprocVecEnv是真正“并行”的vectorized环境.

API非常简洁,而强化算法是一个复杂度较高的系统,很多时候是需要定制化处理的,所以易用性带来了与之对应的抽象程度

二、从环境的基类入手

2.1 空间类型的基类gym.spaces.Space

目的:理解清楚空间类型的抽象类Box, Discrete, Dict,阅读gym.spaces.Space中的源码

  • Q1: 如何抽象输入的多样性?比如从“值的角度”看有连续的vector 或 离散的枚举值,从“维度的角度”看有vector、tensor、image,它们混杂(nested)在一起如既有与velocity相关的连续值vector,又有与位置相关的离散值image,还有one-hot vectory的原始输入,如何解决?
  • A1:用gym.spaces.Space这个类作为基类,从中衍生出Box、Discrete、MultiBinary、MultiDiscrete四种空间类型,举个例子
  1. Space:记录了数据的维度Shape,值的类型dtype,主要提供了类方法sample以支持在空间Space中采样
  2. Box:任意shape的连续空间, 3 × 4 3\times 4 3×4的MatrixBox(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32),也可以是vector
  3. Discrete:维度为1,且有n个枚举值的空间,如 n = 5 n=5 n=5的枚举空间Discrete(5),具体的枚举值为0,1,2,3,4
  4. MultiBinary:任意shape的,值只能为0,1的空间,如MultiBinary((3, 2)).sample()得到array([[0, 1], [1, 0], [1, 1]], dtype=int8)
  5. MultiDiscrete:任意shape,每个维度有自己枚举值的空间,如 5 × 2 × 2 5\times 2\times 2 5×2×2的枚举空间MultiDiscrete(( 5, 2, 2 )),第一维度有0, 1 ,2, 3, 4这5个具体的枚举值,第二、三维度具体枚举值为0, 1

比如说:如果是一个 64 × 64 64\times 64 64×64的image,每个pixel取值范围 [ 0 , 255 ] [0, 255] [0,255],那么可以选择MultiDiscrete ( ( 256 , 256 , . . . , 256 ) ⏟ 64 × 64 ) \Big(\underbrace{(256,256,...,256)}_{64\times 64}\Big) (64×64 (256,256,...,256)),但这很不合理。因为每个枚举值应该有对应的具体意义,所以可以将 [ 0 , 255 ] [0, 255] [0,255] 先预处理normalize 到 [ 0 , 1 ] [0,1] [0,1]看成连续值,那么就可以用Box(low=0, high=1, shape=(64,64),dtype=np.float32)表示了

  • 补充A1:为了处理image、连续值vector、枚举值的混杂(nested)原始输入,有Dict这种空间类型,例子如下:
# 内部用了OrderDict来记住key的插入顺序
spaces.Dict({
        'sensors':  spaces.Dict({
            'position': spaces.Box(low=-100, high=100, shape=(3,)),
            'velocity': spaces.Box(low=-1, high=1, shape=(3,)),
            'front_cam': spaces.Tuple((
                spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                spaces.Box(low=0, high=1, shape=(10, 10, 3))
            )),
            'rear_cam': spaces.Box(low=0, high=1, shape=(10, 10, 3)),
        }),
        'ext_controller': spaces.MultiDiscrete((5, 2, 2)),
        'inner_state':spaces.Dict({
            'charge': spaces.Discrete(100),
            'system_checks': spaces.MultiBinary(10),
            'job_status': spaces.Dict({
                'task': spaces.Discrete(5),
                'progress': spaces.Box(low=0, high=100, shape=()),
            })
        })
    })
  • 小总结:有了空间类型的基本表示(Box,Discrete,Dict),就可以定义observation_space和action_space了

2.2 单个环境的基类gym.Env

目的:理解清楚单个环境gym.Env这个抽象类,阅读gym.core的源码

  • Q2:什么叫一个环境?主要行为是什么?
  • A2-0:核心是输入当前状态和动作值 ( s , a ) (s,a) (s,a),输出下一状态值、奖励值 s ′ , r s',r s,r
  • A2-1:因此环境的核心数据对象有observation_space, action_space,reward_range这三个空间,主要行为step(s, a)->s',r
  • A2-2:其余的reset、close、render、seed只是用来控制环境状态的辅助行为

2.3 多个环境的基类VecEnv

目的:理解vectorised环境即VecEnv抽象类的特别之处,阅读stable_baseline3.vec_env.base_vec_env的源码

  • Q3:VecEnv为啥要单独拿出来抽象?不能让gym.Env衍生出vectorised Env吗?
  • A3:从agent角度为主,若想并行环境,问题主要出在输入类型上。单个环境的输入是obs,多个环境的输入是(num_envs, obs),输出奖励从reward变成(num_envs, reward)
  • 因此,VecEnv的数据对象主要多了self.num_envs,行为从处理单个IO扩展到多个IO,从原本的step(obs, action)变成两步step_async(actions), step_wait(). 让所有环境执行actions,然后wait所有环境执行完毕,再集体返回 num_envs_obs, num_envs_rewards

2.4 串行环境DummyVecEnv(VecEnv)

如果一个环境的主要行为step(s, a)-> s', r很快,如’CartPole-v1’,那么适合用DummyVecEnv来“并行”训练

  • 小总结VecEnv:数据对象有num_envs, observation_space, action_space, reward_range,核心行为是step(),处理的是单个IO到多个IO的类型转换
  • DummyVecEnv:新增vectorized环境实体的数据对象self.envs = list(env0, env1,...),step_wait()的行为逻辑是在self.envs中串行执行的

2.5 并行环境SubprocVecEnv(VecEnv)

如果一个环境的主要行为step(s, a)-> s', r很慢,使得进程切换的开销 < step(s,a)的时间,那么适合用SubprocEnv来并行训练,阅读stable_baselines3.vec_env.sub_proc_env的源码

  • SubprocVecEnv:需要为每一个环境创建一个进程,然后由主进程进行管理,而不像DummyVecEnv那样用list()来存储多个环境

主要逻辑:

  1. 默认用forkserver的方式,起一个资源管理进程ctx
  2. 建立n个环境对应的进程 e 0 , . . . , e n e_0,...,e_n e0,...,en,在ctx与 e i e_i ei之间建立管道pipe的两个连接Connect对象,进行数据交换
  3. 通过管道之间进行数据通信(send(),recv()),用step_async e i e_i ei进程的环境发送命令与数据,并设置等待标识self.waiting = True
  4. 等所有子进程里的环境step完成,再发送给ctx进程,最后在ctx进程中将各环境的数据重组成 (num_envs, data)的形式

其余的是辅助类的环境wrapper,比如统计episode return、episode length的vec_monitor类,对observation和reward进行正则化的vec_normalize类

三、预处理与特征提取

  • 目的:主要对环境内的数据对象进行预处理后,通过神经网络进行特征提取,最后输出对应的特征
  • 阅读stable_baselines3.common.torch_layers
  1. 预处理原始输入 observation,比如环境可以直接返回np.ndarray的image,值为 [ 0 , 255 ] [0, 255] [0,255],那么需要就预处理到 [ 0 , 1 ] [0, 1] [0,1];或者环境内空间是Discrete(5)返回的是枚举值4,那么一般需要进行one-hot变成特征 [ 0 , 0 , 0 , 0 , 1 ] [0, 0, 0, 0, 1] [0,0,0,0,1],得到预处理后的原始特征
  2. 原始特征,通过神经网络的特征提取即Feature Extractor,形成网络表征,最基础的是MlpExtractor来处理vector,CnnExtractor来处理image,CombinedExtrator来处理nested input
  3. 网络表征,再通过Distribution层来建模相应量的分布,如策略分布-actor、V值-critic等

这一部分,stable-baseline3选择抽象仅局限于原始输入observation的网络结构,即Q值网络(输入是observation和action)并没被纳入到torch_layers中

class BaseFeaturesExtractor(nn.Module):
    """
    Base class that represents a features extractor.

    :param observation_space:
    :param features_dim: Number of features extracted.
    """
    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super(BaseFeaturesExtractor, self).__init__()
        assert features_dim > 0
        self._observation_space = observation_space
        self._features_dim = features_dim

四、强化模型的基类

4.1 BaseModel(nn.Module)

包含完整网络结构,有feature extractor、optimizier、网络层参数等数据对象(nn.Module),阅读源码stable_baseline3.common.policies

  1. 数据对象attribute:self.observation_space,self.action_space,self.features_extractor, self.optimizer
  2. 核心类方法 = 主要行为 = instance method:obs_to_tensorsave & loadforwarddevice

BaseModel的骨架(feature extractor)有了,cpu与gpu之间的数据转换方法obs_to_tensor也有了,输出的是网络表征,即action logits

4.2 BasePolicy(BaseModel)

在基础骨架上,加入一些对网络结构的定制化策略,如网络每一层参数的初始化方法init_weights、如学习率的衰减策略_dummy_schedule等,但最后核心的行为是:如何把网络直接输出的action logits映射成环境内有意义的action值

  1. 数据对象attribute:self.squash_output,如果action_logits通过了tanh,即被映射到 [ − 1 , 1 ] [-1, 1] [1,1],那么需要映射回[action_space.low, action_space.high]这种有意义的动作值上
  2. 核心行为:predict将环境的obs变成obs_tensor,经过feature extractor后得到action_tensor,最终形成能直接在环境中执行的、有意义的action_array

4.3 ActorCriticPolicy(BasePolicy)

按照具体配置(feature extractor的类型与结构、网络层具体的初始化方法与激活函数、学习率具体的schedule、action distribution层的具体设置),调用instance method _build(lr_schedule)来创建这些实体

  1. 数据对象attribute:基础预处理器self.feature_extractor,经过特征提取器self.mlp_extractor输出的网络表征,根据action_space的类型带有distribution层的self.action_net,额外线性层的self.value_net,对这四个具体实体的网络层进行初始化,最后是选择定制的self.optimizer (根据这五个核心成员,可知该类是结构为MLP的A2C、PPO等on-policy算法的抽象类)
  2. 核心行为:最核心的是输入obs,输出actions,values及log_prob,即forward(obs, deterministic)-> actions, values ,log_prob,用于前向训练并记录梯度图;
    • 输入 ( s , a ) (s,a) (s,a),输出 V ( s ) V(s) V(s)的values、 log ⁡ π ( a ∣ s ) \log\pi(a|s) logπ(as)的log_prob、以及 π ( ⋅ ∣ s ) \pi(\cdot|s) π(s)的entropy,即方法evaluate_actions(),用于评估状态-动作对
    • 输入 s s s,输出 π ( ⋅ ∣ s ) \pi(\cdot|s) π(s),关于动作的分布,即方法get_distribution(),用于获取动作分布
    • 输入 s s s,输出 V ( s ) V(s) V(s),关于该状态的价值,即方法predict_values(),用于获取状态价值
    • 输入 s s s,输出 a a a,即方法_predict(),用于从分布中采样获取具体动作
    • 继承自BaseModel的方法predict(),输出能直接作用于环境的具体动作,用于与环境进行交互
  • ActorCriticCnnPolicy不过是将feature_extractor变成cnn以提取obs为image时的情况,而不是ActorCriticPolicy中的mlp结构
  • MultiInputActorCriticPolicy是将feature_extractor变成CombinedExtractor以提取obs为nested输入的情况

4.4 ContinuousCritic(BaseModel)

上述BasePolicy、ActorCriticPolicy的抽象目标是policy的直接建模 π ϕ ( ⋅ ∣ s ) \pi_\phi(\cdot|s) πϕ(s),具体为仅局限输入为s的情况,但并没有采用隐式构建Q值网络 Q ψ ( s , a ) Q_\psi(s,a) Qψ(s,a)来隐式建模策略,此处的Continuous为将输入扩充为 ( s , a ) (s,a) (s,a)的情况。

  • Q1: DQN属于哪一抽象类?ActorCriticPolicy类还是ContinuousCritic类?
  • A1: ActorCritic类,因为DQN是输入只有状态s
  1. 数据对象attribute:利用BaseModel中的self.feature_extractor ,新增的有 n n n个critics的self.q_networks,还能选择这n个critics是否share feature_extractor
  2. 核心行为forward,输入s,a,输出n个q值,一般用于前向训练

4.5 总结:RL模型基类的抽象逻辑

  1. 从环境得到的原始输入obs,经过预处理得到原始特征obs_feature,经过feature_extractor得到网络表征features
  2. Feature extractor可以是mlp extractor,也可以是cnn extractor,还可以是combined extractor
  3. 将环境的抽象observation_space&action_space,与self.feature_extractorself.optimizizer组成基类BaseModel (!!!)
  4. 如果将网络表征features当作是动作表征,那么添加与环境交互功能的predict()方法,并对齐动作表征与环境具体动作,组成基类BasePolicy
  5. 如果将网络表征features不仅当作动作表征,而且还视为状态价值的表征,便衍生出了BasePolicy的子类ActorCriticPolicy
  6. 如果将网络表征features当作是q值的表征,具体化起行为,从BaseModel衍生出了ContinuousCritic子类

五、强化存储的基类

存储的设计要点:

  1. 站在全局的角度上,即存储空间的维度与环境数量有关,如(buffer_size, num_envs, obs_shape)
  2. 如有必要,存储的是normalized后的样本

目的:定义对“经验experience”进行存储的空间buffer,存储结构主要分为RolloutBuffer和ReplayBuffer,存储的基本单元可在子类中定义,并提供对该存储空间的“存”“取”“采样”等操作方法。 阅读stable_baselines3.commom.buffers的源码

5.1 BaseBuffer

  1. 数据对象attribute:根据self.observation_space以及self.action_space来决定obs和action的维度,定义一些关于buffer的状态量如full、size等
  2. 基类方法:对这个buffer的基本操作:增、删、查、改(CRUD)+ 核心行为(采样)

5.2 RolloutBuffer(BaseBuffer)

主要用于on-policy方法,兼容并行环境即(num_rollout_steps, num_envs),特殊行为是需要记录return、values,计算advantage,所以需要与value network有所交互

  1. 数据对象attribute,存储的元素如下所示:
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
  1. 核心行为:计算return和advantage,有许多不同的计算方法,如有需要,在这里进行修改
  2. 增删查改的行为:add(obs,action,reward,episode_start,value,log_prob)增、get(batch_size)查、compute_return_and_advantages(last_values, dones)

5.3 ReplayBuffer(BaseBuffer)

主要用于off-policy方法,目前只支持一个环境,不支持多个环境

  1. 存储空间维度是统一的(self.buffer_size, self.n_env, self.obs_shape)
  2. 基本存储单元(obs, action, reward, next_obs, dones, timeouts)
  3. 基本单元的定义在stable_baselines3.common.type_alias中
  4. 只有add和sample的行为被重载了,并且assert n_envs==1

要点记录:环境返回的dones中既包含真正结束的done=1,也包含由于timeout的done=1,因此为了区分真正的timeout,可从环境返回的info中取出因timeout导致的done=1的情况info.get("TimeLimit.truncated", False)

六、强化算法的基类BaseAlgorithm

模型基类BaseModel及其子类抽象出的是一个智能体的决策结构,存储基类BaseBuffer及其子类抽象出的是一个智能体的存储结构,那么"算法"则是定义了智能体与环境之间的交互基本流程。该流程包括,1. 环境的场景处理(env_wrapper) 2. 智能体基本的决策(predict)3. 智能体的学习过程(learn)、学习策略(schedule)等。 最终在算法基类BaseAlgorithm中加入一些便于我们监控智能体行为的观测(callback)、评估(evaluation)过程

  • _wrap_env(env)创建环境
  1. 环境的基类有两种,一是gym.Env,二是stable_baselinse.vec_env.base_vec_env.VecEnv,一律转化为VecEnv进行处理
  2. 环境的wrapper最主要有三种,Monitor_wrapper用于监控episode return&length,VecTransposeImage用于改变以image为输入的维度,VecNormalize用于对obs和reward进行正则化,这些都被写在_wrap_env(env)的方法中
  3. 环境四要素observation_space, action_space, num_envs, reward_range
  • _setup_model(): 创建具体的networks,buffers,optimizer这类实体
  • _setup_lr_schedule() : 设定学习率的衰减策略
  • _setup_learn():初始化训练需要用到的变量如self.action_noise, self._total_timestamps
  • _init_callback(): 创建一个记录评估过程的callback
  • learn():主要的学习过程,返回一个trained model
  • predict():得到能直接与环境交互的动作

你可能感兴趣的:(强化学习)