强化学习Q-Learning实现机器人走迷宫

(参考学习的网址:https://www.imooc.com/article/40166,里面也比我写的更详细,也建议大家去看看)

首先有三部分代码:第一部分是绘制地图代码,第二部分是Q-Learning的代码,第三部分是运行代码

地图如下:

强化学习Q-Learning实现机器人走迷宫_第1张图片

黄色圆形 :   机器人
红色方形 :   炸弹     [reward = -1]
绿色方形 :   宝藏     [reward = +1]
其他方格 :   平地     [reward = 0]
代码如下(env.py):
# -*- coding: UTF-8 -*-

"""
Q Learning 例子的 Maze(迷宫) 环境

黄色圆形 :   机器人
红色方形 :   炸弹     [reward = -1]
绿色方形 :   宝藏     [reward = +1]
其他方格 :   平地     [reward = 0]
"""

import sys
import time

import numpy as np

# Python2 和 Python3 中 Tkinter 的名称不一样
if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk
# 迷宫的宽度
WIDTH = 4
# 迷宫的高度
HEIGHT = 3
# 每个方块的大小(像素值)
UNIT = 40
# 迷宫类
class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # 上,下,左,右 四个 action(动作)
        self.action_space = ['u', 'd', 'l', 'r']
        # action 的数目
        self.n_actions = len(self.action_space)
        self.title('Q Learning')
        # Tkinter 的几何形状,宽度和高度分别乘像素值
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))
        self.build_maze()

    # 创建迷宫
    def build_maze(self):
        # 创建画布 Canvas
        self.canvas = tk.Canvas(self, bg='white',
                                width=WIDTH * UNIT,
                                height=HEIGHT * UNIT)

        # 绘制横纵方格线
        for c in range(0, WIDTH * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, HEIGHT * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, WIDTH * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 零点(左上角)
        origin = np.array([20, 20])

        # 创建我们的探索者 机器人(robot)
        robot_center = origin + np.array([0, UNIT * 2])
        self.robot = self.canvas.create_oval(
            robot_center[0] - 15, robot_center[1] - 15,
            robot_center[0] + 15, robot_center[1] + 15,
            fill='yellow')

        # 陷阱 1
        bomb1_center = origin + UNIT
        self.bomb1 = self.canvas.create_rectangle(
            bomb1_center[0] - 15, bomb1_center[1] - 15,
            bomb1_center[0] + 15, bomb1_center[1] + 15,
            fill='red')

        # 陷阱 2
        bomb2_center = origin + np.array([UNIT * 3, UNIT])
        self.bomb2 = self.canvas.create_rectangle(
            bomb2_center[0] - 15, bomb2_center[1] - 15,
            bomb2_center[0] + 15, bomb2_center[1] + 15,
            fill='red')

        # 宝藏
        treasure_center = origin + np.array([UNIT * 3, 0])
        self.treasure = self.canvas.create_rectangle(
            treasure_center[0] - 15, treasure_center[1] - 15,
            treasure_center[0] + 15, treasure_center[1] + 15,
            fill='green')

        # 设置好上面配置的场景
        self.canvas.pack()

    # 重置(游戏重新开始,将机器人放到左下角)
    def reset(self):
        self.update()
        time.sleep(0.5)
        # 删去机器人
        self.canvas.delete(self.robot) 
        origin = np.array([20, 20])
        robot_center = origin + np.array([0, UNIT * 2])
        # 重新创建机器人
        self.robot = self.canvas.create_oval(
            robot_center[0] - 15, robot_center[1] - 15,
            robot_center[0] + 15, robot_center[1] + 15,
            fill='yellow')
        # 返回 观测(observation)
        return self.canvas.coords(self.robot)

    # 走一步(机器人实施 action)
    def step(self, action):
        #s状态
        s = self.canvas.coords(self.robot)
        #基准
        base_action = np.array([0, 0])
        if action == 0:     # 上
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:   # 下
            if s[1] < (HEIGHT - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:   # 右
            if s[0] < (WIDTH - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:   # 左
            if s[0] > UNIT:
                base_action[0] -= UNIT

        # 移动机器人
        self.canvas.move(self.robot, base_action[0], base_action[1])

        # 下一个 state
        s_ = self.canvas.coords(self.robot)

        # 奖励机制
        if s_ == self.canvas.coords(self.treasure):
            # 找到宝藏,奖励为 1
            reward = 1  
            done = True
            # 终止
            s_ = 'terminal'   
            print("找到宝藏,好棒!")
        elif s_ == self.canvas.coords(self.bomb1):
            # 踩到炸弹1,奖励为 -1
            reward = -1  
            done = True
            # 终止
            s_ = 'terminal'   
            print("炸弹 1 爆炸...")
        elif s_ == self.canvas.coords(self.bomb2):
            # 踩到炸弹2,奖励为 -1
            reward = -1  
            done = True
            # 终止
            s_ = 'terminal'   
            print("炸弹 2 爆炸...")
        else:
            # 其他格子,没有奖励
            reward = 0  
            #非终止
            done = False
        return s_, reward, done

    # 调用 Tkinter 的 update 方法
    def render(self):
        time.sleep(0.1)
        self.update()

其次是Q-Learning部分

流程过程:

强化学习Q-Learning实现机器人走迷宫_第2张图片

代码如下(q_learning.py):

# -*- coding: UTF-8 -*-

"""
Q Learning 算法。做决策的部分,相当于机器人的大脑
"""

import numpy as np
import pandas as pd


class QLearning:
    def __init__(self, actions, learning_rate=0.01, discount_factor=0.9, e_greedy=0.1):
        self.actions = actions        # action 列表
        self.lr = learning_rate       # 学习速率
        self.gamma = discount_factor  # 折扣因子
        self.epsilon = e_greedy       # 贪婪度
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float32)  # Q 表

    # 检测 q_table 中有没有这个 state
    # 如果还没有当前 state, 那我们就插入一组全 0 数据, 作为这个 state 的所有 action 的初始值
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # 插入一组全 0 数据
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    # 根据 state 来选择 action
    def choose_action(self, state):
        self.check_state_exist(state)  # 检测此 state 是否在 q_table 中存在
        # 选行为,用 Epsilon Greedy 贪婪方法
        if np.random.uniform() < self.epsilon:
            # 随机选择 action
            action = np.random.choice(self.actions)
        else:  # 选择 Q 值最高的 action
            state_action = self.q_table.loc[state, :]
            # 同一个 state, 可能会有多个相同的 Q action 值, 所以我们乱序一下
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        return action

    # 学习。更新 Q 表中的值
    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)  # 检测 q_table 中是否存在 s_

        q_predict = self.q_table.loc[s, a]  # 根据 Q 表得到的 估计(predict)值

        # q_target 是现实值
        if s_ != 'terminal':  # 下个 state 不是 终止符
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r  # 下个 state 是 终止符

        # 更新 Q 表中 state-action 的值
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

主函数(main.py):

# -*- coding: UTF-8 -*-

"""
游戏的主程序,调用q_learning和env
"""

from env import Maze
from q_learning import QLearning
def update():
    for episode in range(100):
        # 初始化 state(状态)
        state = env.reset()

        step_count = 0  # 记录走过的步数

        while True:
            # 更新可视化环境
            env.render()

            # RL 大脑根据 state 挑选 action
            action = RL.choose_action(str(state))

            # 探索者在环境中实施这个 action, 并得到环境返回的下一个 state, reward 和 done (是否是踩到炸弹或者找到宝藏)
            state_, reward, done = env.step(action)

            step_count += 1  # 增加步数

            # 机器人大脑从这个过渡(transition) (state, action, reward, state_) 中学习
            RL.learn(str(state), action, reward, str(state_))

            # 机器人移动到下一个 state
            state = state_

            # 如果踩到炸弹或者找到宝藏, 这回合就结束了
            if done:
                print("回合 {} 结束. 总步数 : {}\n".format(episode+1, step_count))
                break

    # 结束游戏并关闭窗口
    print('游戏结束')
    env.destroy()


if __name__ == "__main__":
    # 创建环境 env 和 RL
    env = Maze()
    RL = QLearning(actions=list(range(env.n_actions)))

    # 开始可视化环境
    env.after(100, update)
    env.mainloop()

    print('\nQ 表:')
    print(RL.q_table)

(以上为课程学习内容)

你可能感兴趣的:(python,人工智能,强化学习)