This article is reproduced and adapted from 蓝斯诺特 (lansinuote): A Concise Tutorial on the Stable-Baselines3 Reinforcement Learning Framework (SB3, Stable Baselines).
https://www.bilibili.com/video/BV1ty4y197JE/
https://github.com/lansinuote/StableBaselines3_SimpleCases
Stable Baselines is a set of improved implementations of Reinforcement Learning (RL) algorithms based on OpenAI Baselines.
Stable Baselines3 (also known as SB3) is the PyTorch version of Stable Baselines and provides reliable implementations of reinforcement learning algorithms.
Related links
Stable Baselines3
Stable Baselines
Other resources
Actions (gym.spaces), each of which is constructed and sampled once in the sketch below:
Box: an N-dimensional box that contains every point in the action space.
Discrete: a list of possible actions, where only one of the actions can be used at each timestep.
MultiDiscrete: a list of possible actions, where one action from each discrete set can be used at each timestep.
MultiBinary: a list of possible actions, where any combination of the actions can be used at each timestep.
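A minimal sketch of how each of these space types is constructed and sampled (using only standard gym.spaces constructors; the sizes are arbitrary):
import gym
import numpy as np
#Box: 2 continuous values in [-1, 1]
box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
#Discrete: one of 4 actions, encoded as 0..3
discrete = gym.spaces.Discrete(4)
#MultiDiscrete: one choice from each of three discrete sets
multi_discrete = gym.spaces.MultiDiscrete([3, 2, 2])
#MultiBinary: any combination of 4 binary switches
multi_binary = gym.spaces.MultiBinary(4)
for space in [box, discrete, multi_discrete, multi_binary]:
    print(space, space.sample())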
Environment requirements
pip install stable-baselines3[extra]
The extra install pulls in Tensorboard, OpenCV and atari-py for training on Atari games. If you do not need these features, install with:
pip install stable-baselines3
or:
pip install git+https://github.com/DLR-RM/stable-baselines3
Install gym; version 0.26.2 is recommended:
pip install gym==0.26.2
Note that some packages install an older version of gym as a dependency; upgrade to the newer version if that happens.
After a successful installation, check the versions:
import stable_baselines3 as sb3
import gym
sb3.__version__, gym.__version__
Tests
All unit tests in SB3 can be run with pytest:
pip install pytest pytest-cov
make pytest
You can also run static type checks with pytype and mypy:
pip install pytype mypy
make type
Use flake8 for code style checks:
pip install flake8
make lint
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
env = MyWrapper()
env.reset()
from stable_baselines3 import PPO
#verbose: (int) verbosity level. 0: no output, 1: info, 2: debug
model = PPO('MlpPolicy', env, verbose=0)
model
from stable_baselines3.common.evaluation import evaluate_policy
#evaluate; the first number is the mean episode reward, the second is the std
evaluate_policy(model, env, n_eval_episodes=20)
#train
model.learn(total_timesteps=2_0000, progress_bar=True)
evaluate_policy(model, env, n_eval_episodes=20)
# (801.55, 380.8214115566508)
#save the model
model.save('models/save')
#load the model
model = PPO.load('models/save')
#to continue training, the model needs an env again, because the env cannot be saved together with the model
model.set_env(env)
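To actually continue training, call learn again on the loaded model; a minimal sketch (the extra timestep budget is illustrative):
#continue training the loaded model, then evaluate again
model.learn(total_timesteps=1_0000, progress_bar=True)
evaluate_policy(model, env, n_eval_episodes=20)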
import gym
#define a custom Wrapper
class Pendulum(gym.Wrapper):
def __init__(self):
env = gym.make('Pendulum-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
Pendulum().reset()
#run an environment with random actions
def test(env, wrap_action_in_list=False):
print(env)
state = env.reset()
over = False
step = 0
while not over:
action = env.action_space.sample()
if wrap_action_in_list:
action = [action]
next_state, reward, over, _ = env.step(action)
if step % 20 == 0:
print(step, state, action, reward)
if step > 200:
break
state = next_state
step += 1
test(Pendulum())
#limit the maximum number of steps
class StepLimitWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.current_step = 0
def reset(self):
self.current_step = 0
return self.env.reset()
def step(self, action):
self.current_step += 1
state, reward, done, info = self.env.step(action)
#override the done flag
if self.current_step >= 100:
done = True
return state, reward, done, info
test(StepLimitWrapper(Pendulum()))
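For comparison, gym ships a built-in TimeLimit wrapper that caps episode length in the same way; a minimal sketch, applied to the raw Pendulum-v1 env because TimeLimit follows gym 0.26's 5-tuple step API rather than the simplified 4-tuple wrapper above:
from gym.wrappers import TimeLimit
#cap episodes at 100 steps; reset/step follow the gym 0.26 API
limited = TimeLimit(gym.make('Pendulum-v1'), max_episode_steps=100)
state, info = limited.reset()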
import numpy as np
#modify the action space
class NormalizeActionWrapper(gym.Wrapper):
def __init__(self, env):
#get the original action space
action_space = env.action_space
#the action space must be continuous
assert isinstance(action_space, gym.spaces.Box)
#redefine the action space as continuous values between -1 and 1
#this only affects what env.action_space.sample returns
#the underlying env still works with values between -2 and 2
env.action_space = gym.spaces.Box(low=-1,
high=1,
shape=action_space.shape,
dtype=np.float32)
super().__init__(env)
def reset(self):
return self.env.reset()
def step(self, action):
#rescale the action back to the env's native range and clip to [-2, 2]
action = np.clip(action * 2.0, -2.0, 2.0)
return self.env.step(action)
test(NormalizeActionWrapper(Pendulum()))
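gym also provides a RescaleAction wrapper for the same purpose; a minimal sketch, reusing the Pendulum wrapper and test function defined above:
from gym.wrappers import RescaleAction
#expose an action space of [-1, 1] and rescale actions back to Pendulum's [-2, 2]
test(RescaleAction(Pendulum(), min_action=-1.0, max_action=1.0))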
from gym.wrappers import TimeLimit
#modify the observation
class StateStepWrapper(gym.Wrapper):
def __init__(self, env):
#the observation space must be continuous
assert isinstance(env.observation_space, gym.spaces.Box)
#add a new field to the observation space
low = np.concatenate([env.observation_space.low, [0.0]])
high = np.concatenate([env.observation_space.high, [1.0]])
env.observation_space = gym.spaces.Box(low=low,
high=high,
dtype=np.float32)
super().__init__(env)
self.step_current = 0
def reset(self):
self.step_current = 0
return np.concatenate([self.env.reset(), [0.0]])
def step(self, action):
self.step_current += 1
state, reward, done, info = self.env.step(action)
#end the episode once the step limit (100) is reached
if self.step_current >= 100:
done = True
return self.get_state(state), reward, done, info
def get_state(self, state):
#append the new field (normalized step count) to the state
state_step = self.step_current / 100
return np.concatenate([state, [state_step]])
test(StateStepWrapper(Pendulum()))
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
#the Monitor wrapper logs rollout/ep_len_mean and rollout/ep_rew_mean during training, i.e. it adds some logging
#after upgrading gym to 0.26 this stopped working here, possibly because a custom wrapper is used
env = DummyVecEnv([lambda: Monitor(Pendulum())])
A2C('MlpPolicy', env, verbose=1).learn(1000)
Using cpu device
------------------------------------
| time/ | |
| fps | 919 |
| iterations | 100 |
| time_elapsed | 0 |
| total_timesteps | 500 |
| train/ | |
| entropy_loss | -1.44 |
| explained_variance | -0.00454 |
| learning_rate | 0.0007 |
| n_updates | 99 |
| policy_loss | -43.2 |
| std | 1.02 |
| value_loss | 976 |
------------------------------------
------------------------------------
| time/ | |
| fps | 897 |
| iterations | 200 |
| time_elapsed | 1 |
| total_timesteps | 1000 |
| train/ | |
| entropy_loss | -1.43 |
| explained_variance | 4.43e-05 |
| learning_rate | 0.0007 |
| n_updates | 199 |
| policy_loss | -28.2 |
| std | 1.01 |
| value_loss | 936 |
------------------------------------
<stable_baselines3.a2c.a2c.A2C at 0x7f94cc7fee80>
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack
#VecNormalize normalizes both the states and the rewards
env = DummyVecEnv([Pendulum])
env = VecNormalize(env)
test(env, wrap_action_in_list=True)
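The normalization statistics live inside the VecNormalize wrapper rather than in the model, so they are usually saved and restored separately; a minimal sketch (the file name is illustrative):
#save the running statistics and restore them onto a fresh vectorized env
env.save('models/vec_normalize.pkl')
env = VecNormalize.load('models/vec_normalize.pkl', DummyVecEnv([Pendulum]))
#at evaluation time, freeze the statistics and disable reward normalization
env.training = False
env.norm_reward = False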
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
MyWrapper().reset()
# array([ 1.6411768e-02, 2.0951958e-02, -8.3168052e-05, -3.0546727e-02], dtype=float32)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import time
def test_multiple_env(dumm, N):
if dumm:
#DummyVecEnv runs multiple environments sequentially in a single process
env = DummyVecEnv([MyWrapper] * N)
else:
#SubprocVecEnv runs multiple environments in separate subprocesses
env = SubprocVecEnv([MyWrapper] * N, start_method='fork')
start = time.time()
#train a model
model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=5000)
print('time elapsed=', time.time() - start)
#close the environments
env.close()
#evaluate
return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)
test_multiple_env(dumm=True, N=2)
# time elapsed= 6.74487829208374
test_multiple_env(dumm=True, N=10)
'''
time elapsed= 12.886700868606567
(337.0, 119.68667427913601)
'''
test_multiple_env(dumm=False, N=2)
'''
time elapsed= 9.097013473510742
(371.3, 279.9510850130787)
'''
test_multiple_env(dumm=False, N=10)
'''
time elapsed= 14.181455612182617
(481.85, 260.54352323556236)
'''
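make_vec_env can build the same vectorized envs directly, selecting the backend through vec_env_cls; a minimal sketch:
from stable_baselines3.common.env_util import make_vec_env
#equivalent construction of the two kinds of vectorized envs
env_dummy = make_vec_env(MyWrapper, n_envs=4, vec_env_cls=DummyVecEnv)
env_subproc = make_vec_env(MyWrapper, n_envs=4, vec_env_cls=SubprocVecEnv)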
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
env = MyWrapper()
env.reset()
from stable_baselines3.common.callbacks import BaseCallback
#callback API skeleton
class CustomCallback(BaseCallback):
def __init__(self, verbose=0):
super().__init__(verbose)
#attributes accessible inside a callback
#self.model
#self.training_env
#self.n_calls
#self.num_timesteps
#self.locals
#self.globals
#self.logger
#self.parent
def _on_training_start(self) -> None:
#called before the first rollout starts
pass
def _on_rollout_start(self) -> None:
#called before each rollout starts
pass
def _on_step(self) -> bool:
#called after each env.step(); returning False stops training
return True
def _on_rollout_end(self) -> None:
#called before the parameters are updated
pass
def _on_training_end(self) -> None:
#called before training ends
pass
CustomCallback()
from stable_baselines3 import PPO
#a callback that stops training after a fixed number of steps
class SimpleCallback(BaseCallback):
def __init__(self):
super().__init__(verbose=0)
self.call_count = 0
def _on_step(self):
self.call_count += 1
if self.call_count % 20 == 0:
print(self.call_count)
if self.call_count >= 100:
return False
return True
model = PPO('MlpPolicy', MyWrapper(), verbose=0)
model.learn(8000, callback=SimpleCallback())
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
import gym
def test_callback(callback):
#create a Monitor-wrapped environment; this writes monitor log files into the models folder during training
env = make_vec_env(MyWrapper, n_envs=1, monitor_dir='models')
#equivalent construction
# from stable_baselines3.common.monitor import Monitor
# from stable_baselines3.common.vec_env import DummyVecEnv
# env = Monitor(MyWrapper(), 'models')
# env = DummyVecEnv([lambda: env])
#train
model = A2C('MlpPolicy', env, verbose=0).learn(total_timesteps=5000,
callback=callback)
#evaluate
return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)
#train a model with the Monitor-wrapped environment so that logs get written
#this is only to demonstrate the load_results and ts2xy functions
test_callback(None)
from stable_baselines3.common.results_plotter import load_results, ts2xy
#load the logs; this looks for models/*.monitor.csv
load_results('models')
'''
index r l t
0 0 32.0 32 0.049117
1 1 20.0 20 0.071115
2 2 23.0 23 0.097009
3 3 27.0 27 0.128370
4 4 41.0 41 0.173688
... ... ... ... ...
79 79 88.0 88 5.220903
80 80 80.0 80 5.312054
81 81 82.0 82 5.447958
82 82 121.0 121 5.603233
83 83 238.0 238 5.901814
84 rows × 4 columns
'''
(columns: r = episode reward, l = episode length, t = elapsed time in seconds)
ts2xy(load_results('models'), 'timesteps')
'''
(array([ 32, 52, 75, 102, 143, 153, 169, 183, 200, 212, 227,
252, 283, 317, 358, 377, 399, 421, 445, 505, 537, 627,
...
4363, 4385, 4473, 4553, 4635, 4756, 4994]),
array([ 32., 20., 23., 27., 41., 10., 16., 14., 17., 12., 15.,
25., 31., 34., 41., 19., 22., 22., 24., 60., 32., 90.,
...
32., 22., 88., 80., 82., 121., 238.]))
'''
#save the best model during training
class SaveOnBestTrainingRewardCallback(BaseCallback):
def __init__(self):
super().__init__(verbose=0)
self.best = -float('inf')
def _on_step(self):
#self.n_calls counts from 1
if self.n_calls % 1000 != 0:
return True
#read the logs
x, y = ts2xy(load_results('models'), 'timesteps')
#mean of the last 100 episode rewards
mean_reward = sum(y[-100:]) / len(y[-100:])
print(self.num_timesteps, self.best, mean_reward)
#save if the mean reward improved
if mean_reward > self.best:
self.best = mean_reward
print('save', x[-1])
self.model.save('models/best_model')
return True
test_callback(SaveOnBestTrainingRewardCallback())
'''
1000 -inf 37.23076923076923
save 968
2000 37.23076923076923 45.86046511627907
save 1972
3000 45.86046511627907 54.870370370370374
save 2963
4000 54.870370370370374 58.705882352941174
save 3992
5000 58.705882352941174 65.13333333333334
save 4885
'''
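For reference, SB3 also ships a built-in EvalCallback that periodically evaluates on a separate env and keeps the best model; a minimal sketch (the path and frequencies are illustrative):
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
#built-in alternative: evaluate every 1000 steps and save the best model under models/
eval_callback = EvalCallback(Monitor(MyWrapper()),
                             best_model_save_path='models',
                             eval_freq=1000,
                             n_eval_episodes=5,
                             verbose=0)
test_callback(eval_callback)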
#a callback that prints (or could plot) the training progress
class PlottingCallback(BaseCallback):
def __init__(self, verbose=0):
super().__init__(verbose=0)
def _on_step(self) -> bool:
if self.n_calls % 1000 != 0:
return True
x, y = ts2xy(load_results('models'), 'timesteps')
print(self.n_calls)
print('x=', x)
print('y=', y)
return True
test_callback(PlottingCallback()) # (101.4, 16.73439571660716)
'''
1000
x= [ 50 80 110 136 167 204 240 270 295 322 391 515 601 620 662 699 738 764
805 838 873 904 929 954 979]
y= [ 50. 30. 30. 26. 31. 37. 36. 30. 25. 27. 69. 124. 86. 19.
42. 37. 39. 26. 41. 33. 35. 31. 25. 25. 25.]
2000
x= [ 50 80 110 136 167 204 240 270 295 322 391 515 601 620
...
1632 1692 1752 1781 1814 1850 1898 1918 1959]
y= [ 50. 30. 30. 26. 31. 37. 36. 30. 25. 27. 69. 124. 86. 19.
...
41. 60. 60. 29. 33. 36. 48. 20. 41.]
42.
3000
x= [ 50 80 110 136 167 204 240 270 295 322 391 515 601 620
662 699 738 764 805 838 873 904 929 954 979 1021 1042 1058
...
2227 2264 2308 2351 2387 2431 2475 2508 2558 2765 2802 2842 2889 2999]
y= [ 50. 30. 30. 26. 31. 37. 36. 30. 25. 27. 69. 124. 86. 19.
42. 37. 39. 26. 41. 33. 35. 31. 25. 25. 25. 42. 21. 16.
...
33. 37. 44. 43. 36. 44. 44. 33. 50. 207. 37. 40. 47. 110.]
4000
x= [ 50 80 110 136 167 204 240 270 295 322 391 515 601 620
662 699 738 764 805 838 873 904 929 954 979 1021 1042 1058
...
3088 3197 3311 3359 3567 3798 3924]
y= [ 50. 30. 30. 26. 31. 37. 36. 30. 25. 27. 69. 124. 86. 19.
42. 37. 39. 26. 41. 33. 35. 31. 25. 25. 25. 42. 21. 16.
...
89. 109. 114. 48. 208. 231. 126.]
5000
x= [ 50 80 110 136 167 204 240 270 295 322 391 515 601 620
662 699 738 764 805 838 873 904 929 954 979 1021 1042 1058
...
4976]
y= [ 50. 30. 30. 26. 31. 37. 36. 30. 25. 27. 69. 124. 86. 19.
42. 37. 39. 26. 41. 33. 35. 31. 25. 25. 25. 42. 21. 16.
...
128.]
'''
from tqdm.auto import tqdm
#a callback that updates a progress bar
class ProgressBarCallback(BaseCallback):
def __init__(self):
super().__init__()
self.pbar = tqdm(total=5000)
def _on_step(self):
self.pbar.update(1)
return True
def _on_training_end(self) -> None:
self.pbar.close()
test_callback(ProgressBarCallback())
# (175.3, 13.849548729110273)
#use several callbacks at the same time
test_callback([PlottingCallback(), ProgressBarCallback()])
# (737.35, 255.4991731884861)
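Passing a plain Python list works because SB3 wraps it in a CallbackList internally; the explicit equivalent is a minimal one-liner:
from stable_baselines3.common.callbacks import CallbackList
#explicit equivalent of passing a list of callbacks
test_callback(CallbackList([PlottingCallback(), ProgressBarCallback()]))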
import numpy as np
import gym
from stable_baselines3.common.env_checker import check_env
class GoLeftEnv(gym.Env):
#supported render modes; the human mode is not supported in jupyter
metadata = {'render.modes': ['console']}
def __init__(self):
super().__init__()
#initial position
self.pos = 9
#action space: this environment has only two actions, left and right
self.action_space = gym.spaces.Discrete(2)
#observation space: a one-dimensional number line
self.observation_space = gym.spaces.Box(low=0,
high=10,
shape=(1, ),
dtype=np.float32)
def reset(self):
#reset the position
self.pos = 9
#return the current state
return np.array([self.pos], dtype=np.float32)
def step(self, action):
#apply the action
if action == 0:
self.pos -= 1
if action == 1:
self.pos += 1
self.pos = np.clip(self.pos, 0, 10)
#the episode ends when position 0 is reached
done = self.pos == 0
#reward of 1 when the goal (position 0) is reached, otherwise 0
reward = 1 if self.pos == 0 else 0
return np.array([self.pos], dtype=np.float32), reward, bool(done), {}
def render(self, mode='console'):
if mode != 'console':
raise NotImplementedError()
print(self.pos)
def close(self):
pass
env = GoLeftEnv()
#check that the environment follows the gym interface
check_env(env, warn=True)
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
#wrap the environment
train_env = make_vec_env(lambda: env, n_envs=1)
#define the model
model = PPO('MlpPolicy', train_env, verbose=0)
import gym
#roll out a model in an environment
def test(model, env):
state = env.reset()
over = False
step = 0
for i in range(100):
action = model.predict(state)[0]
next_state, reward, over, _ = env.step(action)
if step % 1 == 0:
print(step, state, action, reward)
state = next_state
step += 1
if over:
break
test(model, env)
'''
0 [9.] 1 0
1 [10.] 0 0
2 [9.] 0 0
3 [8.] 0 0
4 [7.] 0 0
5 [6.] 1 0
...
35 [2.] 0 0
36 [1.] 0 1
'''
model.learn(5000)
#test the trained model
test(model, env)
'''
0 [9.] 0 0
1 [8.] 0 0
2 [7.] 0 0
3 [6.] 0 0
4 [5.] 0 0
5 [4.] 0 0
6 [3.] 0 0
7 [2.] 0 0
8 [1.] 0 1
'''
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('LunarLander-v2')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
env = MyWrapper()
env.reset()
'''
array([-3.9796828e-04, 1.3983431e+00, -4.0319901e-02, -5.5897278e-01,
4.6789032e-04, 9.1330688e-03, 0.0000000e+00, 0.0000000e+00],
dtype=float32)
'''
#get familiar with the game environment
def test_env():
print('env.observation_space=', env.observation_space)
print('env.action_space=', env.action_space)
state = env.reset()
action = env.action_space.sample()
next_state, reward, done, _ = env.step(action)
print('state=', state)
print('action=', action)
print('next_state=', next_state)
print('reward=', reward)
print('done=', done)
test_env()
'''
env.observation_space= Box([-1.5 -1.5 -5. -5. -3.1415927 -5. -0. -0. ],
[1.5 1.5 5. 5. 3.1415927 5. 1. 1. ], (8,), float32)
env.action_space= Discrete(4)
state= [-0.00586414 1.4207945 -0.593993 0.43883783 0.0068019 0.13454841 0. 0. ]
action= 3
next_state= [-0.01165171 1.4300904 -0.58351576 0.41312727 0.01151062 0.09418334 0. 0. ]
reward= 0.921825805844976
done= False
'''
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO
#initialize the model
model = PPO(
policy='MlpPolicy',
env=make_vec_env(MyWrapper, n_envs=4),  #create 4 environments for training
n_steps=1024,
batch_size=64,
n_epochs=4,
gamma=0.999,
gae_lambda=0.98,
ent_coef=0.01,
verbose=0)
model
from stable_baselines3.common.evaluation import evaluate_policy
#evaluate
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (-236.2219022285426, 122.28696606172053)
#train
model.learn(total_timesteps=20_0000, progress_bar=True)
model.save('models/ppo-LunarLander-v2')
model = PPO.load('models/ppo-LunarLander-v2')
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (45.54798891161171, 139.1048836822021)
from huggingface_sb3 import load_from_hub
#!pip install huggingface-sb3
#load a model trained by others
#https://huggingface.co/models?library=stable-baselines3
model = PPO.load(
load_from_hub('araffin/ppo-LunarLander-v2', 'ppo-LunarLander-v2.zip'),
custom_objects={
'learning_rate': 0.0,
'lr_schedule': lambda _: 0.0,
'clip_range': lambda _: 0.0,
},
print_system_info=True,
)
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (250.9974542026721, 86.61020518339575)
'''
== CURRENT SYSTEM INFO ==
- OS: Linux-5.15.0-3.60.5.1.el9uek.x86_64-x86_64-with-glibc2.34 # 2 SMP Wed Oct 19 20:27:31 PDT 2022
- Python: 3.9.15
- Stable-Baselines3: 1.8.0a1
- PyTorch: 1.13.0+cpu
- GPU Enabled: False
- Numpy: 1.23.5
- Gym: 0.26.2
== SAVED MODEL SYSTEM INFO ==
OS: Linux-5.13.0-40-generic-x86_64-with-debian-bullseye-sid #45~20.04.1-Ubuntu SMP Mon Apr 4 09:38:31 UTC 2022
Python: 3.7.10
Stable-Baselines3: 1.5.1a5
PyTorch: 1.11.0
GPU Enabled: False
Numpy: 1.21.2
Gym: 0.21.0
'''
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
MyWrapper().reset()
# array([-0.02186339, -0.00224868, -0.04336443, -0.00508288], dtype=float32)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
#create the training and test environments
env_train = make_vec_env(MyWrapper, n_envs=4)
env_test = Monitor(MyWrapper())
env_train, env_test
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
#evaluate a set of hyperparameters
def test_params(params):
#define a model
model = PPO(
policy='MlpPolicy',
env=env_train,
n_steps=1024,
batch_size=64,
#hyperparameter: n_epochs
n_epochs=params['n_epochs'],
#hyperparameter: gamma
gamma=params['gamma'],
gae_lambda=0.98,
ent_coef=0.01,
verbose=0,
)
#train, using the total_timesteps hyperparameter
model.learn(total_timesteps=params['total_timesteps'], progress_bar=True)
#evaluate
mean_reward, std_reward = evaluate_policy(model,
env_test,
n_eval_episodes=50,
deterministic=True)
#the final score is simply mean minus std; this is the value the study optimizes
score = mean_reward - std_reward
return score
test_params({'n_epochs': 2, 'gamma': 0.99, 'total_timesteps': 500})
import optuna
from optuna.samplers import TPESampler
#define a hyperparameter study
study = optuna.create_study(sampler=TPESampler(),
study_name='PPO-LunarLander-v2',
direction='maximize')
#objective function for the hyperparameter search
def f(trial):
#define the hyperparameters to search and their ranges
params = {
'n_epochs': trial.suggest_int('n_epochs', 3, 5),
'gamma': trial.suggest_uniform('gamma', 0.99, 0.9999),
'total_timesteps': trial.suggest_int('total_timesteps', 500, 2000),
}
#evaluate the sampled hyperparameters
return test_params(params)
study.optimize(f, n_trials=5)
#print the best score and the corresponding hyperparameters
study.best_trial.values, study.best_trial.params
# ([102.40188282816155],
# {'n_epochs': 5, 'gamma': 0.9963042134418639, 'total_timesteps': 580})
#train a model with the best hyperparameters
test_params(study.best_trial.params)
# 133.760702749365
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
env = MyWrapper()
env.reset()
# array([-0.04618305, -0.0019841 , -0.02022721, 0.04200636], dtype=float32)
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
#custom feature extractor
class CustomCNN(BaseFeaturesExtractor):
def __init__(self, observation_space, hidden_dim):
super().__init__(observation_space, hidden_dim)
self.sequential = torch.nn.Sequential(
#[b, 4, 1, 1] -> [b, h, 1, 1]
torch.nn.Conv2d(in_channels=observation_space.shape[0],
out_channels=hidden_dim,
kernel_size=1,
stride=1,
padding=0),
torch.nn.ReLU(),
#[b, h, 1, 1] -> [b, h, 1, 1]
torch.nn.Conv2d(hidden_dim,
hidden_dim,
kernel_size=1,
stride=1,
padding=0),
torch.nn.ReLU(),
#[b, h, 1, 1] -> [b, h]
torch.nn.Flatten(),
#[b, h] -> [b, h]
torch.nn.Linear(hidden_dim, hidden_dim),
torch.nn.ReLU(),
)
def forward(self, state):
b = state.shape[0]
state = state.reshape(b, -1, 1, 1)
return self.sequential(state)
model = PPO('CnnPolicy',
env,
policy_kwargs={
'features_extractor_class': CustomCNN,
'features_extractor_kwargs': {
'hidden_dim': 8
},
},
verbose=0)
model
from stable_baselines3.common.evaluation import evaluate_policy
#evaluate
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (21.1, 9.289241088485108)
#train
model.learn(total_timesteps=2_0000, progress_bar=True)
model.save('models/custom_feature_extractor')
model = PPO.load('models/custom_feature_extractor')
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (522.7, 428.98812337872477)
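When only the size of the policy and value networks needs to change (and not the feature extractor itself), the net_arch entry of policy_kwargs is usually enough; a minimal sketch on the same CartPole env:
#simpler alternative to a custom extractor: just resize the MLP layers
model = PPO('MlpPolicy', env, policy_kwargs={'net_arch': [64, 64]}, verbose=0)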
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
env = MyWrapper()
env.reset()
# array([-0.02676624, -0.00992495, -0.02703292, -0.02751946], dtype=float32)
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
#custom policy network
class CustomNetwork(torch.nn.Module):
def __init__(self,
feature_dim,
last_layer_dim_pi=64,
last_layer_dim_vf=64):
super().__init__()
self.latent_dim_pi = last_layer_dim_pi
self.latent_dim_vf = last_layer_dim_vf
self.policy_net = torch.nn.Sequential(
torch.nn.Linear(feature_dim, last_layer_dim_pi),
torch.nn.ReLU(),
)
self.value_net = torch.nn.Sequential(
torch.nn.Linear(feature_dim, last_layer_dim_vf),
torch.nn.ReLU(),
)
def forward(self, features):
return self.forward_actor(features), self.forward_critic(features)
def forward_actor(self, features):
return self.policy_net(features)
def forward_critic(self, features):
return self.value_net(features)
#policy class that uses the custom network
class CustomActorCriticPolicy(ActorCriticPolicy):
def __init__(self, observation_space, action_space, lr_schedule,
custom_param, *args, **kwargs):
super().__init__(observation_space, action_space, lr_schedule, *args,
**kwargs)
print('custom_param=', custom_param)
self.ortho_init = False
def _build_mlp_extractor(self) -> None:
self.mlp_extractor = CustomNetwork(self.features_dim)
model = PPO(CustomActorCriticPolicy,
env,
policy_kwargs={'custom_param': 'lee'},
verbose=0)
model
# custom_param= lee
from stable_baselines3.common.evaluation import evaluate_policy
#evaluate
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (25.6, 13.440238093129155)
#train
model.learn(total_timesteps=2_0000, progress_bar=True)
model.save('models/custom_policy_network')
model = PPO.load('models/custom_policy_network')
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (241.8, 92.53085971717759)
Install sb3-contrib
pip install sb3-contrib
#this may overwrite gym with an older version; reinstall gym 0.26.2 manually afterwards
pip install gym==0.26.2
import gym
#define the environment
class MyWrapper(gym.Wrapper):
def __init__(self):
env = gym.make('CartPole-v1')
super().__init__(env)
self.env = env
def reset(self):
state, _ = self.env.reset()
return state
def step(self, action):
state, reward, done, _, info = self.env.step(action)
return state, reward, done, info
env = MyWrapper()
env.reset()
# array([-0.02218479, 0.03097541, -0.04123801, 0.02780065], dtype=float32)
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import TRPO
model = TRPO(policy='MlpPolicy', env=env, verbose=0)
model
from stable_baselines3.common.evaluation import evaluate_policy
#evaluate
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (23.5, 7.710382610480495)
#train
model.learn(total_timesteps=2_0000, progress_bar=True)
model.save('models/sb3_contrib_trpo')
model = TRPO.load('models/sb3_contrib_trpo')
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)
# (505.8, 269.82542504367524)
2023-02-17 (Fri)