Different versions of gym and stable-baselines3 conflict with each other, so the combination that worked is recorded here:
gym == 0.21.0
stable-baselines3 == 1.6.2
Install commands:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym==0.21.0
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple stable-baselines3[extra]==1.6.2
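To confirm the installed versions match the combination above, a quick check is to print the version strings (assuming both packages import without errors):

import gym
import stable_baselines3

print(gym.__version__)                # expected: 0.21.0
print(stable_baselines3.__version__)  # expected: 1.6.2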
import gym

# Create environment
env = gym.make("LunarLander-v2")

episodes = 10
for ep in range(episodes):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        # Take a random action
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        env.render()
        rewards += reward
    print(rewards)
env.close()
Rollout with random actions, video link: lunar_lander_random
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")
The trained model is now saved as dqn_lunar.zip.
Load the saved model directly and test it:
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_lunar", env=env)
# Evaluate the trained policy with the built-in helper
mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    deterministic=True,
    render=True,
    n_eval_episodes=10)
print(mean_reward)
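evaluate_policy returns the mean and the standard deviation of the episode reward over n_eval_episodes, so std_reward can also be printed to see how much the return varies between episodes; deterministic=True makes the agent always take the action it currently rates best instead of exploring.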
Writing the test loop manually:
import gym
from stable_baselines3 import DQN

# Create environment
env = gym.make("LunarLander-v2")
# Load the trained agent (no need to instantiate a fresh model first)
model = DQN.load("dqn_lunar", env=env)

episodes = 10
for ep in range(episodes):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        # Greedy action from the trained policy
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        rewards += reward
    print(rewards)
env.close()
Test result: lunar_lander_DQN
As the video above shows, with the default DQN network and hyperparameters the lander still cannot settle stably on the moon. Change the learning rate to 5e-4, the hidden layers to 256 units each, and the number of training timesteps to 2,500,000. The training code is as follows:
import gym
from stable_baselines3 import DQN

# Create environment
env = gym.make("LunarLander-v2")

model = DQN(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=5e-4,
    policy_kwargs={'net_arch': [256, 256]})

model.learn(
    total_timesteps=int(2.5e6),
    progress_bar=True)

model.save("dqn_Net256_lunar_2500K")
The model test code is as follows:
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_Net256_lunar_2500K", env=env)
mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    deterministic=True,
    render=True,
    n_eval_episodes=10)
print(mean_reward)
Test video: lunar_lander_256_2500K
As the video shows, the lander now settles stably on the lunar surface in every episode.
If anything is unclear, check the official documentation:
stable-baselines3: documentation
gym: documentation