【强化学习】Q-Learning算法详解

【强化学习】Q-Learning详解

1、算法思想

Q-Learning 是强化学习算法中基于值迭代的算法。Q 即 Q(s,a),表示在某一时刻的状态 s(s∈S)下采取动作 a(a∈A)能够获得收益的期望;环境会根据 agent 的动作反馈相应的回报 reward r。所以算法的主要思想就是将 State 与 Action 构建成一张 Q-table 来存储 Q 值,然后根据 Q 值来选取能够获得最大收益的动作。

2、公式推导

举个例子,如图有一个 GridWorld 游戏:从起点出发,到达终点为胜利,掉进陷阱为失败。由智能体(Agent)、环境状态(environment)、奖励(reward)、动作(action)可以将问题抽象成一个马尔科夫决策过程:每个格子都算作一个状态 $s_t$;π(a|s) 表示在状态 s 下采取动作 a(a∈A)的策略;P(s’|s,a) 为在 s 状态下选择 a 动作后转换到下一个状态 s’ 的概率;R(s’|s,a) 表示在 s 状态下采取 a 动作转移到 s’ 时获得的奖励 reward。我们的目的很明确,就是找到一条能够到达终点并获得最大奖赏的策略。
这里写图片描述
所以目标就是求出累计奖赏最大的策略的期望:

Goal: $\max_{\pi} E\left[\sum_{t=0}^{H} \gamma^{t} R(S_t, A_t, S_{t+1}) \mid \pi\right]$

Q-Learning 的更新公式为:

$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$
这里写图片描述

4、实现代码

值迭代部分

# -*- coding: utf-8 -*-
from environment import GraphicDisplay, Env

class ValueIteration:
    """Value-iteration agent for a grid world.

    The agent keeps a 2-d table of state values (``value_table[row][col]``)
    and repeatedly applies the Bellman optimality backup to it.

    The environment object must provide ``width``, ``height``,
    ``possible_actions``, ``get_all_states()``,
    ``state_after_action(state, action)`` and ``get_reward(state, action)``.
    States are ``[row, col]`` lists.
    """

    # The goal cell: its value is pinned to 0 and no action is taken from it.
    TERMINAL_STATE = [2, 2]

    def __init__(self, env, discount_factor=0.9):
        """Create an agent with an all-zero value table.

        Args:
            env: the grid-world environment (see class docstring).
            discount_factor: weight applied to the value of the next state.
        """
        self.env = env
        # 2-d list for the value function
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        self.discount_factor = discount_factor

    def _action_value(self, state, action):
        """Return reward + discounted next-state value for one action."""
        next_state = self.env.state_after_action(state, action)
        reward = self.env.get_reward(state, action)
        return reward + self.discount_factor * self.get_value(next_state)

    def value_iteration(self):
        """Run one sweep of value iteration and replace ``value_table``.

        Every non-terminal state receives the maximum backed-up value over
        all possible actions (the Bellman optimality equation), rounded to
        two decimals. The new table is computed entirely from the old one.
        """
        next_value_table = [[0.0] * self.env.width
                            for _ in range(self.env.height)]
        for state in self.env.get_all_states():
            if state == self.TERMINAL_STATE:
                next_value_table[state[0]][state[1]] = 0.0
                continue
            best = max(self._action_value(state, action)
                       for action in self.env.possible_actions)
            next_value_table[state[0]][state[1]] = round(best, 2)
        self.value_table = next_value_table

    def get_action(self, state):
        """Return every action that is greedy w.r.t. the current value table.

        Args:
            state: ``[row, col]`` of the cell to act from.

        Returns:
            A list of all actions tied for the highest backed-up value, in
            the order of ``env.possible_actions``; an empty list for the
            terminal state.
        """
        if state == self.TERMINAL_STATE:
            return []

        action_list = []
        max_value = float("-inf")
        for action in self.env.possible_actions:
            value = self._action_value(state, action)
            if value > max_value:
                action_list = [action]
                max_value = value
            elif value == max_value:
                action_list.append(action)
        return action_list

    def get_value(self, state):
        """Return the stored value of ``state`` rounded to two decimals."""
        return round(self.value_table[state[0]][state[1]], 2)

if __name__ == "__main__":
    # Build the environment, wrap the agent in the GUI and enter the Tk loop.
    env = Env()
    value_iteration = ValueIteration(env)
    grid_world = GraphicDisplay(value_iteration)
    grid_world.mainloop()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

动态环境部分

import tkinter as tk
import time
import numpy as np
import random
from PIL import ImageTk, Image

# Short alias for the Tk-compatible image wrapper.
PhotoImage = ImageTk.PhotoImage
UNIT = 100 # side length of one grid cell, in pixels
HEIGHT = 5 # number of grid rows
WIDTH = 5 # number of grid columns
# NOTE(review): not referenced in the visible code; presumably the probability
# that a chosen action is actually executed - confirm against Env.
TRANSITION_PROB = 1
POSSIBLE_ACTIONS = [0, 1, 2, 3] # action codes: up, down, left, right
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)] # offset per action code; presumably (row, col) deltas - confirm
# NOTE(review): empty and unused in the visible code; presumably filled elsewhere.
REWARDS = []

class GraphicDisplay(tk.Tk):
    """Tk window that visualises a value-iteration agent on the grid world."""

    def __init__(self, value_iteration):
        """Build the window, load the images and draw the initial board.

        Args:
            value_iteration: the agent whose ``value_table`` and
                ``get_action`` drive the display; stored as ``self.agent``.
        """
        super(GraphicDisplay, self).__init__()
        self.title('Value Iteration')
        # Tk geometry strings are "<width>x<height>"; the extra 50 pixels of
        # height leave room below the grid.
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT + 50))
        # Canvas item ids of the text labels / policy arrows currently drawn.
        self.texts = []
        self.arrows = []
        # The display creates its own Env rather than reusing the agent's.
        self.env = Env()
        self.agent = value_iteration
        self.iteration_count = 0
        self.improvement_count = 0
        # 1 while move_by_policy is animating the agent, 0 otherwise.
        self.is_moving = 0
        # Images must be loaded before the canvas is built: _build_canvas
        # reads self.shapes.
        (self.up, self.down, self.left,
         self.right), self.shapes = self.load_images()
        self.canvas = self._build_canvas()
        # Reward labels for the three special cells, given as (row, col).
        self.text_reward(2, 2, "R : 1.0")
        self.text_reward(1, 2, "R : -1.0")
        self.text_reward(2, 1, "R : -1.0")

def _build_canvas(self):
    """Create, populate and pack the main canvas.

    Adds the four control buttons, the grid lines and the agent / obstacle /
    goal images, and stores the agent's canvas item id in ``self.rectangle``.

    Returns:
        The packed ``tk.Canvas``.
    """
    canvas = tk.Canvas(self, bg='white',
                       height=HEIGHT * UNIT,
                       width=WIDTH * UNIT)

    # (label, callback, x position as a fraction of the canvas width)
    button_specs = [
        ("Calculate", self.calculate_value, 0.13),
        ("Print Policy", self.print_optimal_policy, 0.37),
        ("Move", self.move_by_policy, 0.62),
        ("Clear", self.clear, 0.87),
    ]
    for label, callback, x_fraction in button_specs:
        button = tk.Button(self, text=label, command=callback)
        button.configure(width=10, activebackground="#33B5E5")
        canvas.create_window(WIDTH * UNIT * x_fraction, (HEIGHT * UNIT) + 10,
                             window=button)

    # Grid: one vertical line per column and one horizontal line per row,
    # spaced UNIT pixels apart.
    for x in range(0, WIDTH * UNIT, UNIT):
        canvas.create_line(x, 0, x, HEIGHT * UNIT)
    for y in range(0, HEIGHT * UNIT, UNIT):
        # Horizontal lines span the full canvas width.
        canvas.create_line(0, y, WIDTH * UNIT, y)

    # Images are placed by pixel centre: the agent in the top-left cell,
    # two shapes[1] images and one shapes[2] image on fixed cells.
    self.rectangle = canvas.create_image(50, 50, image=self.shapes[0])
    canvas.create_image(250, 150, image=self.shapes[1])
    canvas.create_image(150, 250, image=self.shapes[1])
    canvas.create_image(250, 250, image=self.shapes[2])

    canvas.pack()

    return canvas

def load_images(self):
    """Load and resize the arrow and shape images from ``../img``.

    Returns:
        A pair ``((up, down, left, right), (rectangle, triangle, circle))``
        of ``ImageTk.PhotoImage`` objects; arrows are 13x13 pixels, shapes
        65x65 pixels.
    """
    arrow_size = (13, 13)
    shape_size = (65, 65)

    def photo(path, size):
        # Open the file, scale it, and wrap it for use on a Tk canvas.
        return ImageTk.PhotoImage(Image.open(path).resize(size))

    up = photo("../img/up.png", arrow_size)
    right = photo("../img/right.png", arrow_size)
    left = photo("../img/left.png", arrow_size)
    down = photo("../img/down.png", arrow_size)

    rectangle = photo("../img/rectangle.png", shape_size)
    triangle = photo("../img/triangle.png", shape_size)
    circle = photo("../img/circle.png", shape_size)

    arrows = (up, down, left, right)
    shapes = (rectangle, triangle, circle)
    return arrows, shapes

def clear(self):
    """Reset the display and the agent's value table.

    Does nothing while the agent is being animated (``is_moving != 0``).
    Otherwise zeroes both counters, removes every tracked text and arrow
    item from the canvas, resets ``agent.value_table`` to zeros and moves
    the agent image back to the centre of the top-left cell.
    """
    if self.is_moving != 0:
        return

    self.iteration_count = 0
    self.improvement_count = 0

    # NOTE(review): text_reward also appends to self.texts, so the reward
    # labels are removed here as well - confirm that is intended.
    for item_id in self.texts:
        self.canvas.delete(item_id)
    # Forget the deleted ids so the list does not grow across clears.
    self.texts.clear()

    for item_id in self.arrows:
        self.canvas.delete(item_id)
    self.arrows.clear()

    self.agent.value_table = [[0.0] * WIDTH for _ in range(HEIGHT)]

    # Move by the difference between the current centre and the centre of
    # cell (0, 0).
    x, y = self.canvas.coords(self.rectangle)
    self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)

def reset(self):
    """Refresh the window, pause half a second, then delete the agent image.

    Returns whatever ``canvas.coords`` reports for ``self.rectangle`` after
    the deletion.
    """
    self.update()
    time.sleep(0.5)
    self.canvas.delete(self.rectangle)
    # NOTE(review): coords() is queried after the item has been deleted, so
    # this presumably returns an empty list rather than a position, and
    # self.rectangle is left holding a dead item id - confirm the intent.
    return self.canvas.coords(self.rectangle)

def text_value(self, row, col, contents, font='Helvetica', size=12,
               style='normal', anchor="nw"):
    """Draw ``contents`` near the bottom-right of cell (row, col).

    The text is placed 70 pixels right of and 85 pixels below the cell's
    top-left corner, and its canvas item id is recorded in ``self.texts``.

    Returns:
        None.
    """
    x_offset, y_offset = 70, 85
    pos_x = x_offset + (UNIT * col)
    pos_y = y_offset + (UNIT * row)
    font_spec = (font, str(size), style)
    item_id = self.canvas.create_text(pos_x, pos_y, fill="black",
                                      text=contents, font=font_spec,
                                      anchor=anchor)
    self.texts.append(item_id)

def text_reward(self, row, col, contents, font='Helvetica', size=12,
                style='normal', anchor="nw"):
    """Draw ``contents`` in the top-left corner of cell (row, col).

    The text is placed 5 pixels in from the cell's top-left corner, and its
    canvas item id is recorded in ``self.texts``.

    Returns:
        None.
    """
    margin = 5
    pos_x = margin + (UNIT * col)
    pos_y = margin + (UNIT * row)
    font_spec = (font, str(size), style)
    item_id = self.canvas.create_text(pos_x, pos_y, fill="black",
                                      text=contents, font=font_spec,
                                      anchor=anchor)
    self.texts.append(item_id)

def rectangle_move(self, action):
    """Move the agent image one cell in the direction given by ``action``.

    Action codes: 0 = up, 1 = down, 2 = left, 3 = right. A move that would
    leave the grid is ignored (the image is moved by zero pixels).
    ``self.render()`` is called before the move; it is defined outside this
    excerpt.
    """
    row, col = self.find_rectangle()
    self.render()

    offset = np.array([0, 0])  # [dx, dy] in pixels
    if action == 0:  # up
        if row > 0:
            offset[1] -= UNIT
    elif action == 1:  # down
        if row < HEIGHT - 1:
            offset[1] += UNIT
    elif action == 2:  # left
        if col > 0:
            offset[0] -= UNIT
    elif action == 3:  # right
        if col < WIDTH - 1:
            offset[0] += UNIT

    dx, dy = offset
    self.canvas.move(self.rectangle, dx, dy)  # move agent

def find_rectangle(self):
    """Return the grid cell currently occupied by the agent image.

    Returns:
        ``(row, col)`` as ints, derived from the image's canvas coordinates.
    """
    x, y = self.canvas.coords(self.rectangle)
    # Floor-divide by the cell size so any point inside a cell maps to that
    # cell, not only the exact cell centre.
    return int(y // UNIT), int(x // UNIT)

def move_by_policy(self):
    """Animate the agent from the top-left cell by following the policy.

    Does nothing if ``improvement_count`` is 0 or a move is already running.
    Otherwise the agent image is put back on cell (0, 0) and repeatedly
    moved by a randomly chosen greedy action until ``agent.get_action``
    returns an empty list.
    """
    if self.improvement_count == 0 or self.is_moving == 1:
        return

    self.is_moving = 1
    try:
        # Put the agent back on the centre of the top-left cell.
        x, y = self.canvas.coords(self.rectangle)
        self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)

        row, col = self.find_rectangle()
        actions = self.agent.get_action([row, col])
        while actions:
            # The move must happen synchronously: the next loop iteration
            # reads the new position straight away.
            self.rectangle_move(random.choice(actions))
            # after() without a callback just pauses for the given number
            # of milliseconds.
            self.after(100)
            row, col = self.find_rectangle()
            actions = self.agent.get_action([row, col])
    finally:
        # Always release the flag so the buttons keep working after an error.
        self.is_moving = 0

def draw_one_arrow(self, col, row, action):
    if col == 2 and row == 2:
        return

你可能感兴趣的:(【强化学习】Q-Learning算法详解)