The open-source codebase for MAPPO: https://github.com/marlbenchmark/on-policy
The main MAPPO implementation lives in the onpolicy package; the following walks through the MAPPO code files one by one.
First comes the basic actor-critic architecture, implemented in r_mappo/algorithm/r_actor_critic.py.
import torch
import torch.nn as nn
from onpolicy.algorithms.utils.util import init, check
from onpolicy.algorithms.utils.cnn import CNNBase
from onpolicy.algorithms.utils.mlp import MLPBase
from onpolicy.algorithms.utils.rnn import RNNLayer
from onpolicy.algorithms.utils.act import ACTLayer
from onpolicy.algorithms.utils.popart import PopArt
from onpolicy.utils.util import get_shape_from_obs_space
class R_Actor(nn.Module):
"""
MAPPO 的 Actor 网络类实现。输出给定观察的动作。
:param args: (argparse.Namespace) 模型相关参;
:param obs_space: (gym.Space) 观测空间;
:param action_space: (gym.Space) 动作空间;
:param device: (torch.device) gpu / cpu
"""
def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
super(R_Actor, self).__init__()
self.hidden_size = args.hidden_size
self._gain = args.gain
self._use_orthogonal = args.use_orthogonal
self._use_policy_active_masks = args.use_policy_active_masks
self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
self._use_recurrent_policy = args.use_recurrent_policy
self._recurrent_N = args.recurrent_N
self.tpdv = dict(dtype=torch.float32, device=device)
obs_shape = get_shape_from_obs_space(obs_space)
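## Choose the feature-extraction base: a CNN for image-like (3-dimensional) observations, an MLP otherwise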
base = CNNBase if len(obs_shape) == 3 else MLPBase
self.base = base(args, obs_shape)
if self._use_naive_recurrent_policy or self._use_recurrent_policy:
self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)
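## The ACT layer maps the extracted features to an action distribution matching the action space type and samples/evaluates actions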
self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain)
self.to(device)
def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
"""
根据给定的输入,计算动作
:param obs: (np.ndarray / torch.Tensor) 网络的观测输入
:param rnn_states: (np.ndarray / torch.Tensor) 如果是 RNN 网络,则为 RNN 的隐藏状态
:param masks: (np.ndarray / torch.Tensor) mask tensor 表示隐藏状态是否应该重新初始化为零
:param available_actions: (np.ndarray / torch.Tensor) 表示智能体可用的动作
(若为 None,则所有动作均可用)
:param deterministic: (bool) 是否从动作分布中采样,或者返回模式
:return actions: (torch.Tensor) 要执行的动作
:return action_log_probs: (torch.Tensor) log probabilities of taken actions.
:return rnn_states: (torch.Tensor) 更新 RNN 的隐藏状态
"""
obs = check(obs).to(**self.tpdv)
rnn_states = check(rnn_states).to(**self.tpdv)
masks = check(masks).to(**self.tpdv)
if available_actions is not None:
available_actions = check(available_actions).to(**self.tpdv)
## Extract features from the observation (actor_features); the base is either a CNN or an MLP
actor_features = self.base(obs)
## If a recurrent policy is used
if self._use_naive_recurrent_policy or self._use_recurrent_policy:
actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)
## Produce the action
actions, action_log_probs = self.act(actor_features, available_actions, deterministic)
return actions, action_log_probs, rnn_states
def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
"""
计算给定动作的对数概率和熵
:param obs: (torch.Tensor) 网络的观测输入
:param action: (torch.Tensor) 要计算熵和对数概率的动作
:param rnn_states: (torch.Tensor) 如果是 RNN,则为 RNN 的隐藏状态
:param masks: (torch.Tensor) 掩码张量,表示隐藏状态是否应该重新初始化为零。
:param available_actions: (torch.Tensor) 表示智能体有哪些可用的动作
(如果为 None, 则所有动作均可用)
:param active_masks: (torch.Tensor) 表示一个智能体是活动的还是无效的
:return action_log_probs: (torch.Tensor) log probabilities of the input actions.
:return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
"""
obs = check(obs).to(**self.tpdv)
rnn_states = check(rnn_states).to(**self.tpdv)
action = check(action).to(**self.tpdv)
masks = check(masks).to(**self.tpdv)
if available_actions is not None:
available_actions = check(available_actions).to(**self.tpdv)
if active_masks is not None:
active_masks = check(active_masks).to(**self.tpdv)
## Extract actor features from the observation
actor_features = self.base(obs)
## If a recurrent policy is used
if self._use_naive_recurrent_policy or self._use_recurrent_policy:
actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)
## Call evaluate_actions of the ACT layer to compute action_log_probs and dist_entropy
action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features, action, available_actions, active_masks = active_masks if self._use_policy_active_masks else None)
return action_log_probs, dist_entropy
class R_Critic(nn.Module):
"""
MAPPO 的 Critic 网络,给定集中输入或者局部的观测输出值函数预测。
ß
:param args: (argparse.Namespace) arguments containing relevant model information.
:param cent_obs_space: (gym.Space) (centralized) 观测空间.
:param device: (torch.device) specifies the device to run on (cpu/gpu).
"""
def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
super(R_Critic, self).__init__()
self.hidden_size = args.hidden_size
self._use_orthogonal = args.use_orthogonal
self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
self._use_recurrent_policy = args.use_recurrent_policy
self._recurrent_N = args.recurrent_N
self._use_popart = args.use_popart
self.tpdv = dict(dtype=torch.float32, device=device)
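## Select the weight initializer by indexing with the bool _use_orthogonal: False -> xavier_uniform_, True -> orthogonal_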
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]
cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
base = CNNBase if len(cent_obs_shape) == 3 else MLPBase
self.base = base(args, cent_obs_shape)
if self._use_naive_recurrent_policy or self._use_recurrent_policy:
self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)
def init_(m):
return init(m, init_method, lambda x: nn.init.constant_(x, 0))
if self._use_popart:
self.v_out = init_(PopArt(self.hidden_size, 1, device=device))
else:
self.v_out = init_(nn.Linear(self.hidden_size, 1))
self.to(device)
def forward(self, cent_obs, rnn_states, masks):
"""
Compute value function predictions from the given inputs.
:param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
:param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
:param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.
:return values: (torch.Tensor) value function predictions.
:return rnn_states: (torch.Tensor) updated RNN hidden states.
"""
cent_obs = check(cent_obs).to(**self.tpdv)
rnn_states = check(rnn_states).to(**self.tpdv)
masks = check(masks).to(**self.tpdv)
critic_features = self.base(cent_obs)
if self._use_naive_recurrent_policy or self._use_recurrent_policy:
critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)
values = self.v_out(critic_features)
return values, rnn_states
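To make the expected tensor shapes concrete, here is a minimal usage sketch (not part of the repo) that builds an actor and a critic and runs one forward pass on dummy inputs. It assumes that get_config() from onpolicy/config.py supplies defaults for every hyperparameter these classes read; the Box(18)/Box(54)/Discrete(5) spaces and the batch size are hypothetical choices made only for illustration.
import numpy as np
from gym import spaces
from onpolicy.config import get_config
from onpolicy.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor, R_Critic

args = get_config().parse_args([])                             # repo default hyperparameters (assumed sufficient)
obs_space = spaces.Box(low=-1.0, high=1.0, shape=(18,))        # hypothetical per-agent observation space
cent_obs_space = spaces.Box(low=-1.0, high=1.0, shape=(54,))   # hypothetical centralized observation space
act_space = spaces.Discrete(5)                                 # hypothetical discrete action space

actor = R_Actor(args, obs_space, act_space)
critic = R_Critic(args, cent_obs_space)

batch = 8
obs = np.random.randn(batch, 18).astype(np.float32)
cent_obs = np.random.randn(batch, 54).astype(np.float32)
rnn_states = np.zeros((batch, args.recurrent_N, args.hidden_size), dtype=np.float32)
masks = np.ones((batch, 1), dtype=np.float32)                  # 1 = keep RNN state, 0 = reset it to zeros

actions, action_log_probs, rnn_states_actor = actor(obs, rnn_states, masks)
values, rnn_states_critic = critic(cent_obs, rnn_states.copy(), masks)
print(actions.shape, action_log_probs.shape, values.shape)     # each is torch.Size([8, 1])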
Next comes the MAPPO policy class, implemented in r_mappo/algorithm/rMAPPOPolicy.py.
import torch
from onpolicy.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor, R_Critic
from onpolicy.utils.util import update_linear_schedule
class R_MAPPOPolicy:
"""
MAPPO 的策略类,包装了 actor 网络和 critic 网络来计算动作和值函数预测。
:param args: (argparse.Namespace) 包含相关模型和策略信息的参数;
:param obs_space: (gym.Space) 观测空间
:param cent_obs_space: (gym.Space) 值函数输入空(MAPPO 集中输入,IPPO 分散输入);
:param action_space: (gym.Space) 动作空间
:param device: (torch.device) 选择设备 gpu / cpu
"""
def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
self.device = device
self.lr = args.lr
self.critic_lr = args.critic_lr
self.opti_eps = args.opti_eps
self.weight_decay = args.weight_decay
self.obs_space = obs_space
## cent_obs_space: the centralized observation space, e.g. the N agents' observations concatenated (N * obs_space)
self.share_obs_space = cent_obs_space
self.act_space = act_space
"""
actor 与 critic 网络
"""
self.actor = R_Actor(args, self.obs_space, self.act_space, self.device)
self.critic = R_Critic(args, self.share_obs_space, self.device)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
lr=self.lr, eps=self.opti_eps,
weight_decay=self.weight_decay)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
lr=self.critic_lr,
eps=self.opti_eps,
weight_decay=self.weight_decay)
def lr_decay(self, episode, episodes):
"""
降低 actor 和 critic 的学习率
:param episode: (int) 当前训练的轮次
:param episodes: (int) 总共训练轮次ß
"""
update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)
def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,
deterministic=False):
"""
Compute actions and value function predictions for the given inputs.
:param cent_obs (np.ndarray): centralized input to the critic.
:param obs (np.ndarray): local agent inputs to the actor.
:param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
:param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
:param masks: (np.ndarray) denotes points at which RNN states should be reset.
:param available_actions: (np.ndarray) denotes which actions are available to agent
(if None, all actions available)
:param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
:return values: (torch.Tensor) value function predictions.
:return actions: (torch.Tensor) actions to take.
:return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
:return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
:return rnn_states_critic: (torch.Tensor) updated critic network RNN states.
"""
actions, action_log_probs, rnn_states_actor = self.actor(obs,
rnn_states_actor,
masks,
available_actions,
deterministic)
values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)
return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic
def get_values(self, cent_obs, rnn_states_critic, masks):
"""
Get value function predictions.
:param cent_obs (np.ndarray): centralized input to the critic.
:param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
:param masks: (np.ndarray) denotes points at which RNN states should be reset.
:return values: (torch.Tensor) value function predictions.
"""
values, _ = self.critic(cent_obs, rnn_states_critic, masks)
return values
def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks,
available_actions=None, active_masks=None):
"""
Get action log probabilities / entropy and value function predictions for the actor update.
:param cent_obs (np.ndarray): centralized input to the critic.
:param obs (np.ndarray): local agent inputs to the actor.
:param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
:param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
:param action: (np.ndarray) actions whose log probabilities and entropy to compute.
:param masks: (np.ndarray) denotes points at which RNN states should be reset.
:param available_actions: (np.ndarray) denotes which actions are available to agent
(if None, all actions available)
:param active_masks: (torch.Tensor) denotes whether an agent is active or dead.
:return values: (torch.Tensor) value function predictions.
:return action_log_probs: (torch.Tensor) log probabilities of the input actions.
:return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
"""
action_log_probs, dist_entropy = self.actor.evaluate_actions(obs,
rnn_states_actor,
action,
masks,
available_actions,
active_masks)
values, _ = self.critic(cent_obs, rnn_states_critic, masks)
return values, action_log_probs, dist_entropy
def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
"""
Compute actions using the given inputs.
:param obs (np.ndarray): local agent inputs to the actor.
:param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
:param masks: (np.ndarray) denotes points at which RNN states should be reset.
:param available_actions: (np.ndarray) denotes which actions are available to agent
(if None, all actions available)
:param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
"""
actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
return actions, rnn_states_actor
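A similarly minimal rollout-time sketch (same hypothetical spaces and default-parser assumption as the earlier sketch; not repo code): the critic consumes the centralized observation, the actor consumes only the local observation, and both RNN states are threaded through the call.
import numpy as np
from gym import spaces
from onpolicy.config import get_config
from onpolicy.algorithms.r_mappo.algorithm.rMAPPOPolicy import R_MAPPOPolicy

args = get_config().parse_args([])
obs_space = spaces.Box(low=-1.0, high=1.0, shape=(18,))
cent_obs_space = spaces.Box(low=-1.0, high=1.0, shape=(54,))
act_space = spaces.Discrete(5)
policy = R_MAPPOPolicy(args, obs_space, cent_obs_space, act_space)

n = 8                                                          # e.g. n_rollout_threads * num_agents, flattened
obs = np.random.randn(n, 18).astype(np.float32)
cent_obs = np.random.randn(n, 54).astype(np.float32)
rnn_states_actor = np.zeros((n, args.recurrent_N, args.hidden_size), dtype=np.float32)
rnn_states_critic = np.zeros_like(rnn_states_actor)
masks = np.ones((n, 1), dtype=np.float32)

values, actions, action_log_probs, rnn_states_actor, rnn_states_critic = \
    policy.get_actions(cent_obs, obs, rnn_states_actor, rnn_states_critic, masks)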
Finally comes the trainer class that computes the losses and updates the actor and critic networks, implemented in r_mappo/r_mappo.py.
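Before reading the trainer code, it helps to keep in view the objective it optimizes. In standard PPO notation (a reference formula, not text from the repo), with importance ratio $r_t(\theta)$ and advantage estimate $\hat{A}_t$, ppo_update minimizes a clipped surrogate policy loss minus an entropy bonus, and cal_value_loss regresses the returns with an optional value clip and Huber penalty (active-agent masking and PopArt/ValueNorm normalization are handled in the code and omitted here):
$$
r_t(\theta) = \exp\big(\log \pi_\theta(a_t \mid o_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid o_t)\big),
$$
$$
L^{\text{policy}}(\theta) = -\,\mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big] - c_{\text{ent}}\,\mathbb{E}_t\big[\mathcal{H}[\pi_\theta](o_t)\big],
$$
$$
L^{\text{value}}(\phi) = \mathbb{E}_t\Big[\max\Big(\ell\big(\hat{R}_t - V_\phi(s_t)\big),\ \ell\big(\hat{R}_t - V^{\text{clip}}_\phi(s_t)\big)\Big)\Big],
\qquad
V^{\text{clip}}_\phi(s_t) = V_{\text{old}}(s_t) + \operatorname{clip}\big(V_\phi(s_t) - V_{\text{old}}(s_t),\,-\epsilon,\,\epsilon\big),
$$
where $\ell$ is the squared or Huber error, $\epsilon$ is clip_param, and $c_{\text{ent}}$ is entropy_coef.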
import numpy as np
import torch
import torch.nn as nn
from onpolicy.utils.util import get_gard_norm, huber_loss, mse_loss
from onpolicy.utils.valuenorm import ValueNorm
from onpolicy.algorithms.utils.util import check
class R_MAPPO():
"""
Trainer class for MAPPO to update policies.
:param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
:param policy: (R_MAPPO_Policy) policy to update.
:param device: (torch.device) specifies the device to run on (cpu/gpu).
"""
def __init__(self,
args,
policy,
device=torch.device("cpu")):
self.device = device
self.tpdv = dict(dtype=torch.float32, device=device)
self.policy = policy
self.clip_param = args.clip_param
self.ppo_epoch = args.ppo_epoch
self.num_mini_batch = args.num_mini_batch
self.data_chunk_length = args.data_chunk_length
self.value_loss_coef = args.value_loss_coef
self.entropy_coef = args.entropy_coef
self.max_grad_norm = args.max_grad_norm
self.huber_delta = args.huber_delta
self._use_recurrent_policy = args.use_recurrent_policy
self._use_naive_recurrent = args.use_naive_recurrent_policy
self._use_max_grad_norm = args.use_max_grad_norm
self._use_clipped_value_loss = args.use_clipped_value_loss
self._use_huber_loss = args.use_huber_loss
self._use_popart = args.use_popart
self._use_valuenorm = args.use_valuenorm
self._use_value_active_masks = args.use_value_active_masks
self._use_policy_active_masks = args.use_policy_active_masks
assert (self._use_popart and self._use_valuenorm) == False, ("self._use_popart and self._use_valuenorm can not be set True simultaneously")
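## Choice of value normalizer: with PopArt the critic's output layer itself tracks running return statistics,
## so it doubles as the normalizer; with ValueNorm a separate module tracks them; otherwise returns are used unnormalized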
if self._use_popart:
self.value_normalizer = self.policy.critic.v_out
elif self._use_valuenorm:
self.value_normalizer = ValueNorm(1, device = self.device)
else:
self.value_normalizer = None
def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
"""
Calculate the value function loss.
:param values: (torch.Tensor) value function predictions.
:param value_preds_batch: (torch.Tensor) "old" value predictions from the data batch (used for the value clip loss).
:param return_batch: (torch.Tensor) reward-to-go returns.
:param active_masks_batch: (torch.Tensor) denotes if the agent is active or dead at a given timestep.
:return value_loss: (torch.Tensor) value function loss.
"""
value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
self.clip_param)
if self._use_popart or self._use_valuenorm:
self.value_normalizer.update(return_batch)
error_clipped = self.value_normalizer.normalize(return_batch) - value_pred_clipped
error_original = self.value_normalizer.normalize(return_batch) - values
else:
error_clipped = return_batch - value_pred_clipped
error_original = return_batch - values
if self._use_huber_loss:
value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
value_loss_original = huber_loss(error_original, self.huber_delta)
else:
value_loss_clipped = mse_loss(error_clipped)
value_loss_original = mse_loss(error_original)
if self._use_clipped_value_loss:
value_loss = torch.max(value_loss_original, value_loss_clipped)
else:
value_loss = value_loss_original
if self._use_value_active_masks:
value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
else:
value_loss = value_loss.mean()
return value_loss
def ppo_update(self, sample, update_actor=True):
"""
Update the actor and critic networks.
:param sample: (Tuple) contains the data batch with which to update the networks.
:param update_actor: (bool) whether to update the actor network.
:return value_loss: (torch.Tensor) value function loss.
:return critic_grad_norm: (torch.Tensor) gradient norm from critic update.
:return policy_loss: (torch.Tensor) actor(policy) loss value.
:return dist_entropy: (torch.Tensor) action entropies.
:return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
:return imp_weights: (torch.Tensor) importance sampling weights.
"""
share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
adv_targ, available_actions_batch = sample
old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
adv_targ = check(adv_targ).to(**self.tpdv)
value_preds_batch = check(value_preds_batch).to(**self.tpdv)
return_batch = check(return_batch).to(**self.tpdv)
active_masks_batch = check(active_masks_batch).to(**self.tpdv)
## Reshape so that all time steps are evaluated in a single forward pass
values, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch,
obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
masks_batch,
available_actions_batch,
active_masks_batch)
"""
更新 actor
"""
imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
surr1 = imp_weights * adv_targ
surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
if self._use_policy_active_masks:
policy_action_loss = (-torch.sum(torch.min(surr1, surr2),
dim=-1,
keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
else:
policy_action_loss = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()
policy_loss = policy_action_loss
self.policy.actor_optimizer.zero_grad()
if update_actor:
(policy_loss - dist_entropy * self.entropy_coef).backward()
if self._use_max_grad_norm:
actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
else:
actor_grad_norm = get_gard_norm(self.policy.actor.parameters())
self.policy.actor_optimizer.step()
"""
更新 critic
"""
value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)
self.policy.critic_optimizer.zero_grad()
(value_loss * self.value_loss_coef).backward()
if self._use_max_grad_norm:
critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
else:
critic_grad_norm = get_gard_norm(self.policy.critic.parameters())
self.policy.critic_optimizer.step()
return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights
def train(self, buffer, update_actor=True):
"""
Perform a training update using minibatch gradient descent.
:param buffer: (SharedReplayBuffer) buffer containing the training data.
:param update_actor: (bool) whether to update actor network.
:return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
"""
if self._use_popart or self._use_valuenorm:
advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
else:
advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
advantages_copy = advantages.copy()
advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
## Ignore NaN entries when computing the mean and standard deviation
mean_advantages = np.nanmean(advantages_copy)
std_advantages = np.nanstd(advantages_copy)
advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)
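## Inactive (dead) agent steps were excluded from the mean/std via the NaN trick above, so the normalization statistics come only from active agents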
train_info = {}
train_info['value_loss'] = 0
train_info['policy_loss'] = 0
train_info['dist_entropy'] = 0
train_info['actor_grad_norm'] = 0
train_info['critic_grad_norm'] = 0
train_info['ratio'] = 0
for _ in range(self.ppo_epoch):
## Choose a data generator that matches the policy type (recurrent / naive recurrent / feed-forward)
if self._use_recurrent_policy:
data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
elif self._use_naive_recurrent:
data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
else:
data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)
## Sample minibatches from the data generator
for sample in data_generator:
## Perform one PPO update
value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights \
= self.ppo_update(sample, update_actor)
## Record training statistics
train_info['value_loss'] += value_loss.item()
train_info['policy_loss'] += policy_loss.item()
train_info['dist_entropy'] += dist_entropy.item()
train_info['actor_grad_norm'] += actor_grad_norm
train_info['critic_grad_norm'] += critic_grad_norm
train_info['ratio'] += imp_weights.mean()
## Total number of updates, used to average the logged statistics
num_updates = self.ppo_epoch * self.num_mini_batch
for k in train_info.keys():
train_info[k] /= num_updates
return train_info
def prep_training(self):
self.policy.actor.train()
self.policy.critic.train()
def prep_rollout(self):
self.policy.actor.eval()
self.policy.critic.eval()
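To close, here is a standalone illustration (random tensors, example hyperparameter values; not repo code) of the core quantities ppo_update and cal_value_loss compute: the importance ratio, the clipped surrogate policy loss, and the clipped Huber value loss. The local huber helper is intended to mirror the repo's huber_loss utility.
import torch

clip_param, huber_delta = 0.2, 10.0           # example values; the repo exposes --clip_param and --huber_delta
B = 32
adv = torch.randn(B, 1)                       # (normalized) advantage estimates
old_logp, new_logp = torch.randn(B, 1), torch.randn(B, 1)
old_values, values, returns = torch.randn(B, 1), torch.randn(B, 1), torch.randn(B, 1)

## Clipped surrogate policy loss (cf. ppo_update)
imp_weights = torch.exp(new_logp - old_logp)
surr1 = imp_weights * adv
surr2 = torch.clamp(imp_weights, 1.0 - clip_param, 1.0 + clip_param) * adv
policy_loss = -torch.min(surr1, surr2).mean()

## Clipped Huber value loss (cf. cal_value_loss)
def huber(e, d):
    a = e.abs()
    return torch.where(a <= d, 0.5 * e ** 2, d * (a - 0.5 * d))

value_pred_clipped = old_values + (values - old_values).clamp(-clip_param, clip_param)
value_loss = torch.max(huber(returns - values, huber_delta),
                       huber(returns - value_pred_clipped, huber_delta)).mean()
print(policy_loss.item(), value_loss.item())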