Stable control of the lunar lander with DQN reinforcement learning based on stable-baselines3

  • Dependencies
  • LunarLander with randomly sampled actions
  • DQN implementation with stable-baselines3
    • Model training
    • Model testing
    • Network architecture tuning
  • Appendix

Dependencies

Mismatched versions of gym and stable-baselines3 can conflict with each other, so the combination that worked for me is recorded here:
gym == 0.21.0
stable-baselines3 == 1.6.2
Installation commands:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym==0.21.0
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple stable-baselines3[extra]==1.6.2
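
Since the whole point is matching versions, a quick sanity check after installation is worthwhile; a minimal sketch:

import gym
import stable_baselines3

# Should print 0.21.0 and 1.6.2 respectively
print(gym.__version__)
print(stable_baselines3.__version__)

Note that LunarLander-v2 also needs Box2D; if gym.make later fails with a Box2D import error, installing the extra (e.g. pip install gym[box2d]==0.21.0) should resolve it.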

LunarLander with randomly sampled actions

import gym


# Create environment
env = gym.make("LunarLander-v2")

episodes = 10
for ep in range(episodes):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        # Sample a random action from the action space
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        env.render()
        rewards += reward
    print(rewards)
env.close()

Random actions, video link: lunar_lander_random
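
For reference, LunarLander-v2 exposes an 8-dimensional observation (position, velocity, angle, angular velocity, and two leg-contact flags) and 4 discrete actions (do nothing, fire left engine, fire main engine, fire right engine). This can be confirmed directly from the environment:

import gym

env = gym.make("LunarLander-v2")
# Expect Box(8,) for observations and Discrete(4) for actions
print(env.observation_space)
print(env.action_space)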

DQN implementation with stable-baselines3

Model training

import gym
from stable_baselines3 import DQN


# Create environment
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")

The trained model is saved here as dqn_lunar.zip.
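
If 2e5 steps turn out to be insufficient, training can be resumed from the saved zip instead of starting over; a minimal sketch (note that save() does not store the replay buffer, so the reloaded agent starts with an empty buffer):

import gym
from stable_baselines3 import DQN

env = gym.make("LunarLander-v2")
# Reload the saved agent and continue training
model = DQN.load("dqn_lunar", env=env)
model.learn(total_timesteps=int(1e5), progress_bar=True)
model.save("dqn_lunar")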

Model testing

Load the trained model directly and evaluate it:

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_lunar", env=env)


# 测试接口
mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    deterministic=True,
    render=True,
    n_eval_episodes=10)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

Writing the test loop manually:

import gym
from stable_baselines3 import DQN


# Create environment
env = gym.make("LunarLander-v2")
# Load the trained agent
model = DQN.load("dqn_lunar", env=env)


episodes = 10
for ep in range(episodes):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        # Greedy (deterministic) action from the trained policy
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        rewards += reward
    print(rewards)
env.close()

Test results: lunar_lander_DQN
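
According to the Gym documentation, a LunarLander-v2 episode counts as solved at a return of about 200. A small sketch (reusing dqn_lunar.zip from above) that counts how many test episodes clear that threshold:

import gym
from stable_baselines3 import DQN

env = gym.make("LunarLander-v2")
model = DQN.load("dqn_lunar", env=env)

solved = 0
episodes = 10
for ep in range(episodes):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        rewards += reward
    # 200 is roughly the "solved" threshold for LunarLander-v2
    if rewards >= 200:
        solved += 1
print(f"{solved}/{episodes} episodes reached a return of 200")
env.close()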

Network architecture tuning

As the video above shows, with the default DQN network and hyperparameters the lander still cannot touch down stably on the Moon. The learning rate is therefore changed to 5e-4, the hidden layers are widened to 256 units each, and training is extended to 2,500,000 timesteps. The training code is as follows:

import gym
from stable_baselines3 import DQN


# Create environment
env = gym.make("LunarLander-v2")
model = DQN(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=5e-4,
    policy_kwargs={"net_arch": [256, 256]})

model.learn(
    total_timesteps=int(2.5e6),
    progress_bar=True)

model.save("dqn_Net256_lunar_2500K")

Model test code:

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_Net256_lunar_2500K", env=env)

mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    deterministic=True,
    render=True,
    n_eval_episodes=10)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

Test video: lunar_lander_256_2500K
As the video shows, the lander now touches down stably on the lunar surface in every episode.
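
A quantitative comparison of the two saved models backs this up; a sketch assuming both dqn_lunar.zip and dqn_Net256_lunar_2500K.zip are in the working directory:

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("LunarLander-v2")
for name in ["dqn_lunar", "dqn_Net256_lunar_2500K"]:
    model = DQN.load(name, env=env)
    mean_reward, std_reward = evaluate_policy(
        model, model.get_env(), deterministic=True, n_eval_episodes=10)
    print(f"{name}: {mean_reward:.1f} +/- {std_reward:.1f}")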

Appendix

If you run into problems, consult the official documentation:
stable-baselines3: documentation
gym: documentation
