gym创建自己的强化学习环境env

创建自己的用于强化学习的环境env

1,在C:\Users\xxx\anaconda3\envs\pytorch\Lib\site-packages\gym\envs\classic_control下创建环境文件MyEnv.py
2,在C:\Users\xxx\anaconda3\envs\pytorch\Lib\site-packages\gym\envs\__init__.py中注册

# Register the custom environment under the id "MyEnv-v0".
register(id="MyEnv-v0", entry_point="gym.envs.classic_control.MyEnv:MyEnv")

3,在C:\Users\xxx\anaconda3\envs\pytorch\Lib\site-packages\gym\envs\classic_control\__init__.py中导入自己的环境

from gym.envs.classic_control.MyEnv import MyEnv    # 创建自己的环境

4,开始写环境文件

import gym
from numpy import random
import time

class MyEnv(gym.Env):
    """Deterministic 5x5 grid-world maze with walls and a single exit.

    Cells are numbered 1..25 row by row, starting at the top-left corner::

          1    2    3   [4]   5
          6    7    8   [9]  10
        [11] [12]  13   14  (15)
         16   17   18   19   20
         21   22  [23] [24] [25]

    Bracketed cells are walls and ``(15)`` is the exit.

    Actions are the strings ``'n'``, ``'e'``, ``'s'`` and ``'w'``.

    * An ordinary move costs ``-1``.
    * Walking into a wall yields ``-100``, ends the episode and sets the
      state to ``0``.
    * Walking into the exit yields ``+10``, ends the episode and sets the
      state to ``15``.
    * Any other state/action pair (moving off the grid, or an action that
      is not one of the four strings) leaves the agent in place and
      costs ``-1``.
    """

    # Grid layout.
    GRID_ROWS = 5
    GRID_COLS = 5
    WALLS = (4, 9, 11, 12, 23, 24, 25)
    EXIT_STATE = 15
    DEAD_STATE = 0      # pseudo-state reported after walking into a wall

    # Rendering geometry, in pixels.
    CELL_WIDTH = 60
    CELL_HEIGHT = 40
    EDGE_X = 0
    EDGE_Y = 0

    def __init__(self):
        """Build the state list, reward table and transition table."""
        super().__init__()
        self.viewer = None
        # No current cell until reset() is called.
        self.state = None

        # State space: every cell of the grid, walls included.
        self.states = list(range(1, self.GRID_ROWS * self.GRID_COLS + 1))
        # Action space: north, east, south, west.
        self.actions = ['n', 'e', 's', 'w']

        # Terminal rewards, keyed by "<cell>_<action>".
        self.rewards = {
            # Moves that enter the exit cell.
            '10_s': 10,
            '14_e': 10,
            '20_n': 10,
            # Moves that run into a wall cell.
            '3_e': -100,
            '5_w': -100,
            '8_e': -100,
            '10_w': -100,
            '14_n': -100,
            '6_s': -100,
            '7_s': -100,
            '16_n': -100,
            '17_n': -100,
            '13_w': -100,
            '18_s': -100,
            '19_s': -100,
            '20_s': -100,
            '22_e': -100,
        }

        # Deterministic transitions, keyed by "<cell>_<action>" -> next cell.
        # Moves into the exit or into a wall are intentionally absent here;
        # they are resolved through ``self.rewards`` instead.
        self.t = {
            '1_s': 6, '1_e': 2,
            '2_w': 1, '2_s': 7, '2_e': 3,
            '3_w': 2, '3_s': 8,
            '5_s': 10,
            '6_n': 1, '6_e': 7,
            '7_w': 6, '7_n': 2, '7_e': 8,
            '8_w': 7, '8_n': 3, '8_s': 13,
            '10_n': 5,
            '13_n': 8, '13_e': 14, '13_s': 18,
            '14_w': 13, '14_s': 19,
            '16_e': 17, '16_s': 21,
            '17_w': 16, '17_e': 18, '17_s': 22,
            '18_w': 17, '18_e': 19, '18_n': 13,
            '19_w': 18, '19_n': 14, '19_e': 20,
            '20_w': 19,
            '21_n': 16, '21_e': 22,
            '22_w': 21, '22_n': 17,
        }

        # Pixel centre of every cell (index = cell number - 1); wall cells
        # keep a 0 placeholder because the agent can never stand on them.
        self.x = [0 if s in self.WALLS
                  else self._cell_origin(s)[0] + self.CELL_WIDTH * 0.5
                  for s in self.states]
        self.y = [0 if s in self.WALLS
                  else self._cell_origin(s)[1] + self.CELL_HEIGHT * 0.5
                  for s in self.states]

    def _cell_origin(self, state):
        """Return the (x, y) pixel position of a cell's bottom-left corner."""
        row, col = divmod(state - 1, self.GRID_COLS)
        # Row 0 is the top row; the grid is drawn one cell height above EDGE_Y.
        return (self.EDGE_X + col * self.CELL_WIDTH,
                self.EDGE_Y + (self.GRID_ROWS - row) * self.CELL_HEIGHT)

    def step(self, action):
        """Apply ``action`` to the current cell.

        Args:
            action: one of ``'n'``, ``'e'``, ``'s'``, ``'w'``.

        Returns:
            ``(next_state, reward, is_terminal, info)`` where ``info`` is
            always an empty dict.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.state is None:
            raise RuntimeError("reset() must be called before step()")

        state = self.state
        print('当前状态state:', state)
        # Lookup key shared by the transition and reward tables.
        key = "%d_%s" % (state, action)
        print('key :', key)

        is_terminal = False
        if key in self.t:
            # Ordinary move to a free neighbouring cell.
            next_state = self.t[key]
            r = -1
        elif key in self.rewards:
            # Terminal move: a positive reward means the exit was reached,
            # anything else means the agent ran into a wall.
            r = self.rewards[key]
            next_state = self.EXIT_STATE if r > 0 else self.DEAD_STATE
            is_terminal = True
        else:
            # Not in either table (off the grid or unknown action): stay put.
            next_state = state
            r = -1

        self.state = next_state
        return next_state, r, is_terminal, {}

    def reset(self):
        """Place the agent on a random free cell and return that cell.

        Wall cells are excluded, and so is the exit: no move leads out of
        the exit cell, so an episode started there could never terminate.
        """
        starts = [s for s in self.states
                  if s not in self.WALLS and s != self.EXIT_STATE]
        # random.random() is in [0, 1), so the index is always in range.
        self.state = starts[int(random.random() * len(starts))]
        return self.state

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

    def render(self, mode="human"):
        """Draw the maze and the agent.

        Args:
            mode: ``'rgb_array'`` asks the viewer for an RGB array; any
                other value is forwarded as ``return_rgb_array=False``.

        Returns:
            Whatever ``Viewer.render`` returns for the requested mode.
        """
        # Imported lazily so the environment can be used without a display.
        from gym.envs.classic_control import rendering

        width = self.CELL_WIDTH
        height = self.CELL_HEIGHT

        left = self.EDGE_X
        right = self.EDGE_X + self.GRID_COLS * width
        bottom = self.EDGE_Y + height
        top = self.EDGE_Y + (self.GRID_ROWS + 1) * height

        if self.viewer is None:
            # The window must reach the top of the grid, otherwise the
            # first row of cells is clipped.
            self.viewer = rendering.Viewer(right, top)

        square = [(0, 0), (0, height), (width, height), (width, 0)]

        def fill_cell(state, color):
            """Paint one grid cell in the given RGB colour."""
            self.viewer.draw_polygon(square, filled=True, color=color).add_attr(
                rendering.Transform(translation=self._cell_origin(state)))

        # Walls are black, the exit is yellow.
        for wall in self.WALLS:
            fill_cell(wall, (0, 0, 0))
        fill_cell(self.EXIT_STATE, (1, 0.9, 0))

        # Grid lines.
        for i in range(self.GRID_ROWS + 1):
            y = bottom + i * height
            self.viewer.draw_line((left, y), (right, y))
        for j in range(self.GRID_COLS + 1):
            x = left + j * width
            self.viewer.draw_line((x, bottom), (x, top))

        # The agent is a circle; skip it when there is no valid cell to draw
        # it on (before reset(), or after walking into a wall -> state 0).
        if self.state in self.states:
            self.viewer.draw_circle(18, color=(0.8, 0.6, 0.4)).add_attr(
                rendering.Transform(
                    translation=(self.x[self.state - 1], self.y[self.state - 1])))

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

调用自己的环境

import gym

# Build the registered environment and run one demonstration step.
env = gym.make("MyEnv-v0")
env.reset()
# step() requires an action; MyEnv uses the strings 'n', 'e', 's', 'w'.
env.step('e')
env.render()

你可能感兴趣的:(pytorch,python,深度学习)