gym_super_mario_bros
nes_py
matplotlib
pytorch
stable_baselines3
numpy
nes_py 库是一个带有 OpenAI Gym 接口的 NES(任天堂红白机)模拟器,并非任天堂官方开发;其中的 JoypadSpace 包装器把手柄按键组合映射为离散动作,作用类似于游戏手柄。
gym_super_mario_bros里面有各个关卡的环境模型,这里选用SuperMarioBros-v0(标准画面的完整游戏,从第一关 1-1 开始)。
stable_baselines3是对强化学习新手极其友好的库,基于pytorch进行开发,可以方便地使用各种常用强化学习算法的代码。
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
import time
from matplotlib import pyplot as plt
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3 import PPO
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
import numpy as np
import os
from stable_baselines3.common.callbacks import BaseCallback
# Training environment pipeline:
#   emulator -> SIMPLE_MOVEMENT action set -> Monitor logging
#   -> grayscale observations -> vectorized env -> 4 stacked frames.
log_dir = './monitor_log/'
os.makedirs(log_dir, exist_ok=True)

mario_env = JoypadSpace(
    gym_super_mario_bros.make('SuperMarioBros-v0'),
    SIMPLE_MOVEMENT,
)
monitored_env = Monitor(mario_env, log_dir)
gray_env = GrayScaleObservation(monitored_env, keep_dim=True)

# DummyVecEnv takes a list of factories; the single factory hands back the
# already-wrapped environment built above.
vec_env = DummyVecEnv([lambda: gray_env])
env = VecFrameStack(vec_env, 4, channels_order='last')
class SaveOnStepCallback(BaseCallback):
    """Save a checkpoint of the model at a fixed call interval.

    Checkpoints are written below ``<save_path>/best_model`` and are named
    ``_<n_calls>``, where ``n_calls`` is the callback's call counter at the
    moment of saving.

    Args:
        check_freq: Save whenever ``n_calls`` is a multiple of this value.
        save_path: Parent directory; ``best_model`` is appended to it.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = os.path.join(save_path, 'best_model')

    def _init_callback(self):
        """Create the checkpoint folder if it does not exist yet."""
        if self.save_path is None:
            return
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save the model when the call counter hits the interval."""
        checkpoint_due = self.n_calls % self.check_freq == 0
        if checkpoint_due:
            checkpoint_name = '_{}'.format(self.n_calls)
            self.model.save(os.path.join(self.save_path, checkpoint_name))
        # Always True. NOTE(review): SB3 documents a False return from
        # _on_step as a request to stop training - confirm against the docs.
        return True
# PPO training configuration.
learning_rate = 1e-6
n_steps = 2048
tensorboard_log = r'./tensorboard_logs/'

# CNN policy over the stacked grayscale frames produced by `env`.
model = PPO(
    "CnnPolicy",
    env,
    verbose=1,
    learning_rate=learning_rate,
    n_steps=n_steps,
    tensorboard_log=tensorboard_log,
)

# Checkpoint root. The previous value r"RL_Mario\\" was a raw string, so it
# contained two literal backslashes: tolerated on Windows, but on POSIX it
# created a directory literally named "RL_Mario\\". A forward-slash relative
# path resolves to the same ./RL_Mario folder on every platform.
save_path = r'./RL_Mario/'

# Checkpoints land in <save_path>/best_model/_<n_calls>.zip every 20000 calls.
callback1 = SaveOnStepCallback(check_freq=20000, save_path=save_path)
model.learn(total_timesteps=5000000, callback=callback1)
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
import time
from matplotlib import pyplot as plt
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
import os
from stable_baselines3 import PPO
import time
from stable_baselines3.common.results_plotter import load_results, ts2xy
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
# Playback environment: must apply the same observation wrappers as training
# (grayscale with kept channel dim, 4 stacked frames, channels last) so the
# loaded policy receives inputs of the shape it was trained on.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)

# Log playback episodes to their own folder. Monitor opens monitor.csv for
# writing, so pointing it at the training folder ('./monitor_log/') would
# truncate the training log every time this script runs.
monitor_dir = r'./monitor_log_eval/'
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, monitor_dir)

env = GrayScaleObservation(env, keep_dim=True)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')

# Checkpoint to replay. In a raw string backslashes are already literal, so
# they must not be doubled (the old value contained "\\\\" separators).
# NOTE(review): this is a machine-specific absolute path - adjust as needed.
save_model_dir = r'E:\Python\Mario_play\RL_Mario\best_model\_2320000.zip'
model = PPO.load(save_model_dir)
# Run the loaded policy in the rendered environment until interrupted.
obs = env.reset()
try:
    while True:
        # No manual reset on `done`: a VecEnv resets itself when an episode
        # ends and step() already returns the first observation of the next
        # episode. The old in-loop reset stored its result in an unused
        # `state` variable, so the policy kept acting on a stale observation.
        #
        # The copy is kept from the original script; presumably it guards
        # against arrays the policy cannot consume directly - TODO confirm.
        action, _states = model.predict(obs.copy())
        obs, rewards, dones, infos = env.step(action)
        # Short pause so the rendered game is watchable.
        time.sleep(0.01)
        env.render()
finally:
    # Release the emulator/window even when the loop is interrupted.
    env.close()
强化学习马里奥