MAPPO是2021年一篇将PPO算法扩展至多智能体的论文,其论文链接地址为:https://arxiv.org/abs/2103.01955
对应的官方代码链接为https://github.com/marlbenchmark/on-policy
所有核心代码都位于 onpolicy 文件夹中。 algorithms/子文件夹包含 MAPPO 的特定于算法的代码。
envs/ 文件夹包含 MPE、SMAC 和 Hanabi 的环境的实现。
用于执行训练部署和策略更新的代码包含在 runner/ 文件夹中 - 每个环境都有一个runner。
可以在 scripts/ 文件夹中找到用于使用默认超参数进行训练的可执行脚本。 这些文件按以下方式命名:train_algo_environment.sh。 在每个文件中,映射名称(在 SMAC 和 MPE 的情况下)可以更改。
可以在 scripts/train/ 文件夹中找到每个环境的 Python 训练脚本。
config.py 文件包含相关的超参数和环境设置。
本配置在Ubuntu 16/18/20,可以在CPU或GPU上跑程序,本文按照GPU版本配置环境
1.出现ImportError: cannot import name ‘get_backend’
解决方法:tensorflow版本问题:sudo pip install tensorflow --upgrade,如果还不行查看:https://stackoverflow.com/questions/54574591/importerror-cannot-import-name-backend以及https://stackoverflow.com/questions/52979322/matplotlib-3-0-0-cannot-import-name-get-backend-from-matplotlib
2.出现TypeError: cannot assign ‘torch.FloatTensor’ as parameter ‘stddev’ (torch.nn.Parameter or None expected)
解决方法:代码的错误,将"onpolicy/algorithms/utils/popart.py"中的 63行代码
self.stddev = (self.mean_sq - self.mean ** 2).sqrt().clamp(min=1e-4)
self.weight = self.weight * old_stddev / self.stddev
self.bias = (old_stddev * self.bias + old_mean - self.mean) / self.stddev
改为:
self.stddev = nn.Parameter((self.mean_sq - self.mean ** 2).sqrt().clamp(min=1e-4))
self.weight = nn.Parameter(self.weight * old_stddev / self.stddev)
self.bias = nn.Parameter((old_stddev * self.bias + old_mean - self.mean) / self.stddev)
最后就可以运行程序了!
为了更好的控制代码,能够便利的debug需要对train_mpe.py做些更改:需要在main()中更改
if __name__ == "__main__":
if len(sys.argv[1:]) == 0:
argv = ['--use_valuenorm', '--use_popart', '--env_name', 'MPE', '--algorithm_name', 'rmappo',
'--experiment_name', 'check', '--scenario_name', 'simple_spread', '--num_agents', '3',
'--num_landmarks', '3', '--seed', '1', '--n_training_threads', '1', '--n_rollout_threads',
'64', '--num_mini_batch', '1', '--episode_length', '25', '--num_env_steps', '2000000',
'--ppo_epoch', '10', '--use_ReLU', '--gain', '0.01', '--lr', '7e-4', '--critic_lr',
'7e-4', '--wandb_name', 'zoeyuchao', '--user_name', 'zoeyuchao', '--use_wandb', 'False']
else:
argv = sys.argv[1:]
print(argv)
main(argv)
这样就能很好的实现debug了!