Training LunarLander-v2 with PPO and DQN from stable-baselines3

Contents

  • stable-baselines3
    • Setting up stable-baselines3
  • LunarLander-v2
    • Setting up the LunarLander-v2 environment
  • PPO method
  • DQN method


stable-baselines3

Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of Stable Baselines.

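Stable-Baselines3 on GitHub: https://github.com/DLR-RM/stable-baselines3
Stable-Baselines3 documentation: https://stable-baselines3.readthedocs.io/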

Setting up stable-baselines3

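$ pip install stable-baselines3

The training code below also passes tensorboard_log, which requires the tensorboard package:

$ pip install tensorboard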

LunarLander-v2

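The OpenAI Gym environment we will train on is LunarLander-v2, which simulates a lunar lander braking and touching down on the Moon. The state is an 8-dimensional vector: six continuous values (horizontal and vertical position, horizontal and vertical velocity, angle, and angular velocity) plus two Boolean flags that indicate whether each landing leg is touching the ground. The action space is discrete with four values, 0, 1, 2 and 3, which mean: do nothing, fire the left orientation engine, fire the main engine, and fire the right orientation engine.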
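To make the state and action layout concrete, here is a small plain-Python sketch that labels each observation component and action index. The field names (x, y, vx, ...) and the helper describe_observation are hypothetical labels of our own; the ordering follows the gym documentation for LunarLander, and the sample observation is made up rather than taken from a real episode. It does not need gym to run.

```python
# Hypothetical labels for the 8 observation values, in the order gym documents them.
OBS_FIELDS = (
    "x", "y",                                 # lander position
    "vx", "vy",                               # linear velocity
    "angle", "angular_velocity",
    "left_leg_contact", "right_leg_contact",  # 1.0 if the leg touches the ground, else 0.0
)

# Meaning of each discrete action index.
ACTION_MEANINGS = {
    0: "do nothing",
    1: "fire left orientation engine",
    2: "fire main engine",
    3: "fire right orientation engine",
}


def describe_observation(obs):
    """Pair each of the 8 observation values with its (hypothetical) field name."""
    if len(obs) != len(OBS_FIELDS):
        raise ValueError(f"expected {len(OBS_FIELDS)} values, got {len(obs)}")
    return dict(zip(OBS_FIELDS, (float(v) for v in obs)))


if __name__ == "__main__":
    # Made-up observation: descending, upright, legs not yet touching the ground
    sample = [0.0, 1.4, 0.0, -0.5, 0.0, 0.0, 0.0, 0.0]
    print(describe_observation(sample))
    print(ACTION_MEANINGS[2])
```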

Setting up the LunarLander-v2 environment

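$ pip install gym
$ pip install Box2D

The code in this article uses the classic gym API from before gym 0.26: env.reset() returns only the observation, and env.step() returns four values (state, reward, done, info). Newer gym and Gymnasium releases changed both signatures, and stable-baselines3 2.x targets Gymnasium, so either install matching older versions (for example gym 0.21 with stable-baselines3 1.x) or adapt the evaluation loops to the newer API.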
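A quick way to confirm that the installs worked, without creating an environment, is to check that the relevant top-level modules can be found. The helper missing_modules below is a hypothetical convenience function built on the standard library's importlib.util.find_spec. The module list is an assumption: it supposes the Box2D bindings provide a module named Box2D, and it includes tensorboard only because the training code passes tensorboard_log.

```python
import importlib.util

# Top-level modules this article relies on (assumed names, see the note above).
REQUIRED_MODULES = ("stable_baselines3", "gym", "Box2D", "tensorboard")


def missing_modules(modules=REQUIRED_MODULES):
    """Return the names in `modules` that cannot be found on the current Python path.

    For top-level names, importlib.util.find_spec only locates the module; it does not import it.
    """
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```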

PPO method

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

env_name = "LunarLander-v2"
# Wrap the environment in a single-environment vectorized wrapper
env = DummyVecEnv([lambda: gym.make(env_name)])

model = PPO(
    "MlpPolicy",
    env=env,
    batch_size=64,
    gae_lambda=0.98,
    gamma=0.999,
    n_epochs=4,
    ent_coef=0.01,
    verbose=1,
    tensorboard_log="./tensorboard/LunarLander-v2/",
)

# total_timesteps is expected to be an int
model.learn(total_timesteps=int(1e6))

# Saved as a .zip archive; SB3 appends the extension automatically
model.save("./model/LunarLander_PPO")

# Watch one episode with the trained agent (classic gym API: reset() -> obs, step() -> 4-tuple)
env = gym.make(env_name)
model = PPO.load("./model/LunarLander_PPO")

state = env.reset()
done = False
score = 0
while not done:
    # deterministic=True takes the most likely action instead of sampling
    action, _ = model.predict(state, deterministic=True)
    state, reward, done, info = env.step(action)
    score += reward
    env.render()
env.close()
print(f"Episode score: {score:.2f}")
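PPO does not update its policy on raw returns: it computes advantages with Generalized Advantage Estimation (GAE), and gamma=0.999 and gae_lambda=0.98 above are the two parameters of that computation. The sketch below implements the standard GAE recursion in plain Python for a single rollout. The function name and argument layout are our own; stable-baselines3's rollout buffer performs the equivalent computation with its own vectorized bookkeeping.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.999, gae_lambda=0.98):
    """Generalized Advantage Estimation for a single rollout of length T.

    rewards[t]  -- reward received after the action at step t
    values[t]   -- critic estimate V(s_t)
    dones[t]    -- 1.0 if the episode ended right after step t, else 0.0
    last_value  -- V(s_T) for the state after the last step (used for bootstrapping)

    Recursion, computed backwards in time:
        delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * gae_lambda * (1 - d_t) * A_{t+1}
    Returns (advantages, returns) with returns[t] = A_t + V(s_t), the critic's target.
    """
    T = len(rewards)
    advantages = [0.0] * T
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 1.0 - dones[t]  # stops bootstrapping and propagation across episode ends
        delta = rewards[t] + gamma * not_done * next_value - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


if __name__ == "__main__":
    adv, ret = compute_gae(
        rewards=[1.0, 1.0, 1.0], values=[0.0, 0.0, 0.0],
        dones=[0.0, 0.0, 1.0], last_value=0.0, gamma=1.0, gae_lambda=1.0,
    )
    print(adv, ret)
```

With gae_lambda=1 the advantages reduce to discounted returns minus the value baseline, and with gae_lambda=0 they reduce to one-step TD errors. Values close to 1, such as 0.98, keep the estimate close to the Monte Carlo return (low bias) while damping some of its variance.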

DQN method

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv

env_name = "LunarLander-v2"
# Wrap the environment in a single-environment vectorized wrapper
env = DummyVecEnv([lambda: gym.make(env_name)])

model = DQN(
    "MlpPolicy",
    env=env,
    learning_rate=5e-4,
    batch_size=128,
    buffer_size=50000,
    learning_starts=0,
    target_update_interval=250,
    policy_kwargs={"net_arch": [256, 256]},  # two hidden layers of 256 units
    verbose=1,
    tensorboard_log="./tensorboard/LunarLander-v2/",
)

# total_timesteps is expected to be an int
model.learn(total_timesteps=int(1e6))

# Saved as a .zip archive; SB3 appends the extension automatically
model.save("./model/LunarLander_DQN")

# Watch one episode with the trained agent (classic gym API: reset() -> obs, step() -> 4-tuple)
env = gym.make(env_name)
model = DQN.load("./model/LunarLander_DQN")

state = env.reset()
done = False
score = 0
while not done:
    # deterministic=True takes the greedy action, skipping epsilon-random exploration
    action, _ = model.predict(state, deterministic=True)
    state, reward, done, info = env.step(action)
    score += reward
    env.render()
env.close()
print(f"Episode score: {score:.2f}")
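Most of the DQN arguments above configure how the Q-network's regression targets are built. Each sampled transition (batch_size=128 of them, drawn from a replay buffer holding up to buffer_size=50000 transitions) gets the target r + gamma * (1 - done) * max_a Q_target(s', a), where Q_target is a copy of the network that stable-baselines3 refreshes every target_update_interval=250 environment steps. Since the code sets no exploration arguments, actions during training are chosen epsilon-greedily with the library defaults, which, as far as we know, anneal epsilon linearly from 1.0 to 0.05 over the first 10% of training (exploration_initial_eps, exploration_final_eps, exploration_fraction). The plain-Python sketch below illustrates both pieces; dqn_targets and linear_epsilon are hypothetical helpers, not SB3 functions, and gamma=0.99 is assumed to match SB3's DQN default.

```python
def dqn_targets(rewards, next_q_target, dones, gamma=0.99):
    """One-step DQN targets: y_i = r_i + gamma * (1 - done_i) * max_a Q_target(s'_i, a).

    next_q_target[i] holds the target network's Q-values (one per action) for next state s'_i.
    Terminal transitions (done_i = 1.0) do not bootstrap from the next state.
    """
    return [
        r + gamma * (1.0 - d) * max(q_next)
        for r, q_next, d in zip(rewards, next_q_target, dones)
    ]


def linear_epsilon(step, total_timesteps, initial_eps=1.0, final_eps=0.05,
                   exploration_fraction=0.1):
    """Epsilon for epsilon-greedy exploration: linear decay over the first
    exploration_fraction * total_timesteps steps, then constant at final_eps."""
    decay_steps = exploration_fraction * total_timesteps
    if step >= decay_steps:
        return final_eps
    return initial_eps + (step / decay_steps) * (final_eps - initial_eps)


if __name__ == "__main__":
    # Two made-up transitions, 4 actions each (LunarLander has 4 discrete actions)
    targets = dqn_targets(
        rewards=[1.0, -100.0],
        next_q_target=[[0.5, 2.0, -1.0, 0.0], [3.0, 1.0, 0.0, 0.0]],
        dones=[0.0, 1.0],
    )
    print(targets)
    for step in (0, 50_000, 100_000, 500_000):
        print(step, round(linear_epsilon(step, total_timesteps=1_000_000), 3))
```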
