再一次被本科生吊打!清华本科生开源强化学习平台天授!

Java面试笔试面经、Java技术每天学习一点

Java面试

关注不迷路

开源最前线(ID:OpenSourceTop) 猿妹编译

项目地址:https://github.com/thu-ml/tianshou

深度强化学习(deep RL)近年来取得了令人瞩目的进步,就说2020年至今,就有多个深度学习框架相继开源,清华的Jittor、旷视的MegEngine、华为的Mindspore,国内首个开源深度学习框架PaddlePaddle等。近日,清华大学又新开源了强化学习平台天授。

再一次被本科生吊打!清华本科生开源强化学习平台天授!_第1张图片

这个项目的主要创建者是Jiayi Weng与Minghao Zhang,他们都是清华的本科生。Jiayi Weng今年6月份本科毕业,在此之前作为本科研究者与清华大学苏航、朱军等老师开展强化学习领域的相关研究。Minghao Zhang目前是清华大学软件学院的本科二年级学生,同时还修了数学专业。

是的,你没听错,就是本科生,是不是感觉又一次被吊打了呢?不过这两个人实力可都是杠杠的,就说Jiayi Weng,小升初的暑假就开始写代码。高二作为全国青少年信息学奥林匹克竞赛(NOI)选手进入省队。高中时期就开始钻研微积分、线性代数,大二上学期就加入了朱军教授领导的TSAIL实验室,大三暑假期间更是去到加拿大图灵奖获得者Bengio教授的实验室,深入开展了RL和NLP的研究。

而且这个项目取名为“天授”,这一词语源自《史记》,意为“取天所授而非学自人类”,刻画了强化学习通过与环境进行交互自主学习,而不需要像监督学习一样需要大量人类标注数据。

天授是什么?

天授(Tianshou)是纯基于 PyTorch 代码的强化学习框架,与目前现有基于TensorFlow 的强化学习库不同,天授的类继承并不复杂,API 也不是很繁琐。支持的 RL 算法包括:

  • Policy Gradient (PG)

  • Deep Q-Network (DQN)

  • Double DQN (DDQN) with n-step returns

  • Advantage Actor-Critic (A2C)

  • Deep Deterministic Policy Gradient (DDPG)

  • Proximal Policy Optimization (PPO)

  • Twin Delayed DDPG (TD3)

  • Soft Actor-Critic (SAC)

为什么要选择天授

速度快:天授是一个轻量级的高速强化学习平台。是在笔记本电脑(i7-8750H + GTX1060)上进行的测试。在CartPole-v0任务上,它仅需3秒就可以训练一个倒立摆(CartPole)。

再一次被本科生吊打!清华本科生开源强化学习平台天授!_第2张图片

上图为天授与各大知名 RL 开源平台在 CartPole 与 Pendulum 环境下的速度对比。所有代码均在配置为 i7-8750H + GTX1060 的同一台笔记本电脑上进行测试。

可复现性:天授有其单元测试。每一次单元测试除了基本功能的测试之外,还包括针对所有算法的完整训练过程,也就是说一旦有一个算法没办法 train 出来结果,单元测试不能通过。据我们所知,得益于天授快速的训练机制,天授是目前唯一一个采用这种标准进行单元测试的强化学习框架

模块化:天授将算法分解为四个部分:

  • init:策略初始化。

  • process_fn:处理函数,从回放缓存中处理数据。

  • call:根据观测值计算操作

  • learn:从给定数据包中学习

接口灵活:用户可以定制各种各样的 training 方法。提供示例,方便用户根据自己的需要进行二次开发

如何使用天授?

这是深度Q网络的一个示例。您还可以在test / discrete / test_dqn.py运行完整的脚本。

首先,导入一些相关的软件包:

import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts

定义一些超参数:

task = 'CartPole-v0'
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
target_freq = 320
batch_size = 64
train_num, test_num = 8, 100
buffer_size = 20000
writer = SummaryWriter('log/dqn')  # tensorboard is also supported!

环境配置:

# you can also try with SubprocVectorEnv
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])

建立网络:

class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        ])
    def forward(self, s, state=None, info={}):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, state

env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)

设置策略和收集器:

policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
    use_target_network=True, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)

训练:

result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
print(f'Finished training! Use {result["duration"]}')

保存/加载策略:

torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))

以35帧率观察模型表现:

collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
collector.close()

查看保存在tensorboard中的结果:

tensorboard --logdir log/dqn

你可以在Github和PyPI上找到天授的最新版本和其他资料。最后附上相关地址:

PyPI:https://pypi.org/project/tianshou/

Github天授主页:https://github.com/thu-ml/tianshou

你可能感兴趣的:(再一次被本科生吊打!清华本科生开源强化学习平台天授!)