Life是一个基于pytorch实现的强化学习库,实现了多种强化学习算法。
项目地址:https://github.com/HanggeAi/Life
训练器的名称和算法的名称是一一对应的,如要训练DQN
,则其训练函数的名称为:
train_dqn
要使用Life进行强化学习,仅需简单的三步,下面以DQN在CartPole环境上的训练为例进行快速入门:
from life.dqn.dqn import DQN # 导入模型
from life.dqn.trainer import train_dqn # 导入训练器
from life.envs.dis_env_demo import make # 环境的一个例子
from life.utils.replay.replay_buffer import ReplayBuffer # 回放池
import torch
import matplotlib.pyplot as plt
# 设置超参数
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cpu") # 也可指定为gpu : torch.device("cuda")
env=make() # 建立环境,这里为 CartPole-v0
replay_buffer = ReplayBuffer(buffer_size) # 回放池
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
# 建立模型
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
target_update, device) # DQN模型
注意,如果你足够细心,你会发现在上述建立DQN的过程中,我们没有传入一个Neural Network,这是因为在建立深度强化学习时,Life提供了一个默认的双层神经网络作为建立DQN的默认神经网络。当然,你也可以使用自己设计的神经网络结构:
class YourNet:
"""your network for your task"""
pass
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
target_update, device, q_net=YourNet) # DQN模型
result=train_dqn(agent,env,replay_buffer,minimal_size,batch_size,num_episodes)
episodes_list = list(range(len(result)))
plt.figure(figsize=(8,6))
plt.plot(episodes_list, result)
plt.xlabel("Episodes")
plt.ylabel("Returns")
plt.title("DQN on {}".format("Cart Pole v1"))
plt.show()
return_agent=True
,这会返回一个元组(return_list, agent)
其中,return_list
为:训练过程中每个回合的汇报,agent
为训练好的智能体。
return_agent
默认为False
可见,除了超参数的设置之外,我们构建DQN算法只使用了两行代码:
from life.dqn.dqn import DQN
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)
我们训练DQN同样只使用了两行代码:
from life.dqn.trainer import train_dqn
result=train_dqn(agent,env,replay_buffer,minimal_size,batch_size,num_episodes)