Atari 环境：同时显示 4 帧堆叠图像（NumPy 数组拼接）

from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack
import os
import numpy as np
import matplotlib.pyplot as plt

# Hide all GPUs (presumably from TensorFlow, which stable_baselines uses);
# this demo only steps the environment and plots frames, so the CPU suffices.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before the
# GPU is first initialised -- consider moving this above the stable_baselines
# import to be safe.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'


def tile_frames(frames):
    """Arrange four stacked frames into a single 2x2 grid image.

    Args:
        frames: array-like of shape ``(H, W, 4)``; ``frames[:, :, i]`` is the
            i-th frame of the stack.

    Returns:
        A ``(2 * H, 2 * W)`` array laid out as::

            frame 0 | frame 1
            --------+--------
            frame 2 | frame 3

    Raises:
        ValueError: if ``frames`` is not a 3-D array with exactly 4 channels.
    """
    frames = np.asarray(frames)
    if frames.ndim != 3 or frames.shape[-1] != 4:
        raise ValueError(
            "expected frames of shape (H, W, 4), got {}".format(frames.shape))
    top = np.concatenate((frames[:, :, 0], frames[:, :, 1]), axis=1)
    bottom = np.concatenate((frames[:, :, 2], frames[:, :, 3]), axis=1)
    return np.concatenate((top, bottom), axis=0)


if __name__ == '__main__':
    env_name = 'PongNoFrameskip-v4'
    env = make_atari_env(env_name, num_env=1, seed=0)
    env = VecFrameStack(env, n_stack=4)

    try:
        # One environment with 4 stacked 84x84 frames: shape (1, 84, 84, 4).
        # Index [0] below selects that single environment's observation.
        obs = env.reset()
        done = np.array([False])
        while not done[0]:
            image = tile_frames(obs[0])  # (168, 168) for 84x84 frames

            # Frames are single-channel, so render them in grayscale instead of
            # matplotlib's default colour map. plt.show() blocks: close each
            # window to advance to the next environment step.
            plt.figure()
            plt.imshow(image, cmap='gray')
            plt.show()
            plt.close()

            # Always take action 0 (presumably NOOP) in the single environment;
            # step() returns one done flag per environment.
            obs, _, done, _ = env.step([0])
    finally:
        env.close()


你可能感兴趣的:(atari env 同时显示4张图片 (array拼接))