A Reinforcement Learning Implementation of Super Mario

Dependencies

gym_super_mario_bros
nes_py
matplotlib
pytorch
stable_baselines3
numpy

Brief Introduction

The nes_py library is an NES emulator with a Python interface; its JoypadSpace wrapper plays the role of a game controller, mapping a small discrete action set onto NES button presses.
gym_super_mario_bros provides Gym environment models for the game's stages; here we use SuperMarioBros-v0, which starts at the first level (World 1-1).
stable_baselines3 is an extremely beginner-friendly reinforcement learning library. It is built on PyTorch and gives you ready-made, reliable implementations of the common RL algorithms. A quick way to see the discrete action set that JoypadSpace exposes is shown in the snippet below.
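
Before training, the wrapped environment can be inspected interactively (a minimal sketch; the values in the comments are what gym_super_mario_bros defines for SIMPLE_MOVEMENT):

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)
print(env.action_space)   # Discrete(7): seven button combinations
print(SIMPLE_MOVEMENT)    # [['NOOP'], ['right'], ['right', 'A'], ...]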

Code


from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback

import numpy as np
import os

# Create the Mario environment and restrict the controller to 7 simple actions
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)


# Log episode rewards and lengths to ./monitor_log/ for later analysis
log_dir = './monitor_log/'
os.makedirs(log_dir, exist_ok=True)

env = Monitor(env, log_dir)

# Preprocess observations: grayscale (keeping the channel axis), wrap in a
# single-env VecEnv, then stack the last 4 frames so the policy can see motion
env = GrayScaleObservation(env, keep_dim=True)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')
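
# Optional sanity check (shape is my assumption, based on the 240x256 NES
# frame): after the wrappers above, a reset observation should come out as
# (1, 240, 256, 4) -- (batch, height, width, stacked frames)
# print(env.reset().shape)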


class SaveOnStepCallback(BaseCallback):
    """Save a checkpoint of the model every `check_freq` training steps."""

    def __init__(self, check_freq, save_path, verbose=1):
        super(SaveOnStepCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = os.path.join(save_path, 'best_model')

    def _init_callback(self):
        # Create the save folder if needed
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        # Checkpoint the model every check_freq calls, named after the step count
        if self.n_calls % self.check_freq == 0:
            self.model.save(os.path.join(self.save_path, '_{}'.format(self.n_calls)))

        return True
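
# The Monitor wrapper above logs episode rewards to log_dir, and the
# load_results/ts2xy helpers imported earlier can read them back. As an
# alternative to fixed-interval checkpoints, here is a sketch (following the
# callback example in the stable_baselines3 docs; not used below) that saves
# only when the mean reward over the last 100 episodes improves:
class SaveOnBestRewardCallback(BaseCallback):
    def __init__(self, check_freq, log_dir, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir
        self.save_path = os.path.join(log_dir, 'best_model')
        self.best_mean_reward = -np.inf

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            # Read episode rewards from the Monitor log
            x, y = ts2xy(load_results(self.log_dir), 'timesteps')
            if len(y) > 0:
                mean_reward = np.mean(y[-100:])
                if mean_reward > self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    self.model.save(self.save_path)
        return True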

# PPO hyperparameters: a small learning rate and the default rollout length
learning_rate = 1e-6
n_steps = 2048

tensorboard_log = './tensorboard_logs/'
model = PPO("CnnPolicy", env, verbose=1,
            learning_rate=learning_rate, n_steps=n_steps,
            tensorboard_log=tensorboard_log)

save_path = './RL_Mario/'
callback1 = SaveOnStepCallback(check_freq=20000, save_path=save_path)
model.learn(total_timesteps=5000000, callback=callback1)
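
Training at this small learning rate takes a long time (millions of steps). While it runs, the reward and loss curves can be watched live with TensorBoard, pointed at the log directory configured above:

tensorboard --logdir ./tensorboard_logs/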

Model Testing

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.monitor import Monitor

import os
import time

# Rebuild the environment with exactly the same wrapper stack used for
# training; otherwise the loaded policy's observation space will not match
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)

monitor_dir = './monitor_log/'
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, monitor_dir)

env = GrayScaleObservation(env, keep_dim=True)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')

# Path to a checkpoint produced during training (adjust to your own machine)
save_model_dir = 'E:\\Python\\Mario_play\\RL_Mario\\best_model\\_2320000.zip'

model = PPO.load(save_model_dir)

obs = env.reset()
obs = obs.copy()  # copy so the frame buffer is a writable array for model.predict
while True:
    action, _states = model.predict(obs)
    # A DummyVecEnv resets automatically when an episode ends,
    # so no manual reset is needed inside the loop
    obs, rewards, done, info = env.step(action)
    obs = obs.copy()
    env.render()
    time.sleep(0.01)  # slow the loop down to roughly watchable speed
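
Instead of only watching the rendered run, the policy can also be scored numerically with the evaluation helper that ships with stable_baselines3 (a minimal sketch; the episode count here is arbitrary):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
print(f'mean reward: {mean_reward:.1f} +/- {std_reward:.1f}')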

Result

[Demo: the trained PPO agent playing Super Mario]
