DQN:深度Q-网络

基本概念

DQN

DQN全名叫Deep Q-Leaning Network,中文名叫深度Q-网络。DQN算法的基本思路来源于Q-Learning,不同于Q-learning,DQN的Q值不是直接通过状态值s和动作a来计算的,而是通过神经网络来计算的。

DQN算法本质上属于Q-Learning算法,在策略选择上和Q-Learning保持一致,采用 ϵ − g r e e d y \epsilon-greedy ϵgreedy策略。在Q-learning的基础上,DQN提出了两个技巧使得Q网络的更新迭代更稳定:

1、经验回放:DQN使用经验池对多条经验 ( s , a , r , s ′ ) (s, a, r, s') (s,a,r,s)进行保存,在训练的时候,随机从经验池中抽取一定数量的数据来进行训练,这样就可以不停的优化网络模型。

2、固定Q目标 Fixed-Q-Target:主要解决算法训练不稳定的问题。复制一个和原来Q网络结构一样的Target Q网络,用于计算Q目标值。DQN中有两个结构相同但是参数不同的网络,当前值( p r e d i c t Q predictQ predictQ)网络用于预测估计的Q值,目标值( t a r g e t Q targetQ targetQ)网络用于预测现实的Q值。当前值网络使用最新的参数,目标值网络会使用很久之前的参数。

其中, t a r g e t Q targetQ targetQ值的计算公式: t a r g e t Q = r + γ ∗ m a x Q ( s ′ , a ∗ ; θ ) targetQ=r+γ∗maxQ(s',a^*;θ) targetQ=r+γmaxQ(s,a;θ
p r e d i c t Q predictQ predictQ的计算公式: p r e d i c t Q = Q ( s , a ; θ ) predictQ=Q(s,a;\theta) predictQ=Q(s,a;θ)

如下图所示,使用均方差损失函数 1 m ∑ j = 1 m ( t a r g e t Q − p r e d i c t Q ) 2 \frac 1 m \sum_{j=1}^{m}(targetQ- predictQ)^2 m1j=1m(targetQpredictQ)2,通过神经网络的梯度反向传播来更新 p r e d i c t Q predictQ predictQ网络的所有参数 θ \theta θ。并且每隔N时间步长,拷贝 p r e d i c t Q predictQ predictQ网络的所有参数到 t a r g e t Q targetQ targetQ网络中。

DQN:深度Q-网络_第1张图片
简而言之,DQN使用 ϵ − g r e e d y \epsilon-greedy ϵgreedy策略来选择动作并执行,采用经验回收机制,使用经验池存储(状态,动作,价值,下一个状态)信息,存储完成之后,以批量的形式获取数据,使用均方差损失函数,采用梯度随机下降法更新当前值( p r e d i c t Q predictQ predictQ)网络的参数,进行当前值网络的训练,并每隔N时间步长,将参数同步到目标值( t a r g e t Q targetQ targetQ)网络。

DQN与Q-Learning的区别:

整体来说,DQN 与Q-Learning的目标价值以及价值的更新方式都非常相似。但是,DQN将Q-Learning与深度学习结合,用深度网络来近似动作价值函数,而Q-Learning则是采用表格存储;DQN 采用了经验回放的训练方法,从历史数据中随机采样,而Q-Learning直接采用下一个状态的数据进行学习。

DQN的算法如下图所示。
DQN:深度Q-网络_第2张图片
DQN:深度Q-网络_第3张图片

在上述代码中, Q ( ϕ j , a j ; θ ) Q(\phi_j,a_j;\theta) Q(ϕj,aj;θ)为当前值( p r e d i c t Q predictQ predictQ)网络预测的Q值, y i = r j + γ m a x a ′ Q ( ϕ j + 1 , a ′ ; θ ) y_i=r_j + \gamma max_{a'}Q(\phi_{j+1},a';\theta) yi=rj+γmaxaQ(ϕj+1,a;θ)为目标值( t a r g e t Q targetQ targetQ)网络预测的Q值。

测试代码

下面是一个DQN贪吃蛇的例子

import random
import sys
from collections import deque

import numpy as np
import pygame as pg
import tensorflow as tf
import cv2 as cv

# 参数
# 游戏帧率
FPS = 5
# 窗口宽度、高度
WINDOW_WIDTH, WINDOW_HEIGHT = 640, 480
# 组成大小
CELL_SIZE = 40
CELL_WIDTH, CELL_HEIGHT = WINDOW_WIDTH // CELL_SIZE, WINDOW_HEIGHT // CELL_SIZE
# 常用颜色
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
DARK_GREEN = (0, 155, 0)
GREEN = (0, 255, 0)
DARK_GRAY = (60, 60, 60)
RED = (255, 0, 0)
# 方向
UP = "up"
DOWN = "down"
LEFT = "left"
RIGHT = "right"
# 神经网络的输出
MOVE_UP = [1, 0, 0, 0]
MOVE_DOWN = [0, 1, 0, 0]
MOVE_LEFT = [0, 0, 1, 0]
MOVE_RIGHT = [0, 0, 0, 1]


def check_for_key_press():
    if len(pg.event.get(pg.QUIT)) > 0:
        pg.quit()
        sys.exit()
    key_up_events = pg.event.get(pg.KEYUP)
    if len(key_up_events) == 0:
        return None
    if key_up_events[0].key == pg.K_ESCAPE:
        pg.quit()
        sys.exit()
    return key_up_events[0].key


def show_start_screen():
    title_font = pg.font.Font("freesansbold.ttf", 100)
    title_surface1 = title_font.render("snake", True, WHITE, DARK_GREEN)
    title_surface2 = title_font.render("snake", True, GREEN)
    degree1 = 0
    degree2 = 0
    press_key_font = pg.font.Font("freesansbold.ttf", 18)
    press_key_surface = press_key_font.render("press a key to play", True, DARK_GRAY)
    while True:
        screen.fill(BLACK)
        # draw snake word
        rotated_surface1 = pg.transform.rotate(title_surface1, degree1)
        rotated_rect1 = rotated_surface1.get_rect()
        rotated_rect1.center = (WINDOW_WIDTH / 2, WINDOW_HEIGHT / 2)
        screen.blit(rotated_surface1, rotated_rect1)
        rotated_surface2 = pg.transform.rotate(title_surface2, degree2)
        rotated_rect2 = rotated_surface2.get_rect()
        rotated_rect2.center = (WINDOW_WIDTH / 2, WINDOW_HEIGHT / 2)
        screen.blit(rotated_surface2, rotated_rect2)
        # draw press key word
        press_key_rect = press_key_surface.get_rect()
        press_key_rect.topleft = (WINDOW_WIDTH - 200, WINDOW_HEIGHT - 30)
        screen.blit(press_key_surface, press_key_rect)
        if check_for_key_press():
            pg.event.get()
            return
        pg.display.update()
        clock.tick(FPS)
        degree1 += 3
        degree2 += 3


def test_not_ok(temp, worm):
    for body in worm:
        if temp['x'] == body['x'] and temp['y'] == body['y']:
            return True
    return False


def get_random_location(worm):
    temp = {'x': random.randint(0, CELL_WIDTH - 1), 'y': random.randint(0, CELL_HEIGHT - 1)}
    while test_not_ok(temp, worm):
        temp = {'x': random.randint(0, CELL_WIDTH - 1), 'y': random.randint(0, CELL_HEIGHT - 1)}
    return temp


# 检查贪吃蛇是否出现180度掉头
def examine_direction(pre_direction):
    if direction == UP and pre_direction == DOWN:
        return False
    if direction == DOWN and pre_direction == UP:
        return False
    if direction == LEFT and pre_direction == RIGHT:
        return False
    if direction == RIGHT and pre_direction == LEFT:
        return False
    return True


def draw_grid():
    for x in range(0, WINDOW_WIDTH, CELL_SIZE):
        pg.draw.line(screen, DARK_GRAY, (x, 0), (x, WINDOW_HEIGHT))
    for y in range(0, WINDOW_HEIGHT, CELL_SIZE):
        pg.draw.line(screen, DARK_GRAY, (0, y), (WINDOW_WIDTH, y))


def draw_worm_coord():
    for body in worm_coord:
        x = body['x'] * CELL_SIZE
        y = body['y'] * CELL_SIZE
        body_rect = pg.Rect(x, y, CELL_SIZE, CELL_SIZE)
        pg.draw.rect(screen, DARK_GREEN, body_rect)
        body_inner_rect = pg.Rect(x + 4, y + 4, CELL_SIZE - 8, CELL_SIZE - 8)
        pg.draw.rect(screen, GREEN, body_inner_rect)


def draw_apple():
    x = apple['x'] * CELL_SIZE
    y = apple['y'] * CELL_SIZE
    apple_rect = pg.Rect(x, y, CELL_SIZE, CELL_SIZE)
    pg.draw.rect(screen, WHITE, apple_rect)


def run_game(action=None):
    global direction, worm_coord, head, apple
    pre_direction = direction
    if action == MOVE_UP and direction != DOWN:
        direction = UP
    elif action == MOVE_DOWN and direction != UP:
        direction = DOWN
    elif action == MOVE_LEFT and direction != RIGHT:
        direction = LEFT
    elif action == MOVE_RIGHT and direction != LEFT:
        direction = RIGHT
    for event in pg.event.get():
        if event.type == pg.QUIT:
            pg.quit()
            sys.exit()
        elif event.type == pg.KEYUP:
            if (event.key == pg.K_LEFT or event.key == pg.K_a) and direction != RIGHT:
                direction = LEFT
            elif (event.key == pg.K_RIGHT or event.key == pg.K_d) and direction != LEFT:
                direction = RIGHT
            elif (event.key == pg.K_UP or event.key == pg.K_w) and direction != DOWN:
                direction = UP
            elif (event.key == pg.K_DOWN or event.key == pg.K_s) and direction != UP:
                direction = DOWN
            elif event.key == pg.K_ESCAPE:
                pg.quit()
                sys.exit()

    reward = 0
    # 检测贪吃蛇有没有碰到墙壁
    if worm_coord[head]['x'] == -1 or worm_coord[head]['x'] == CELL_WIDTH \
            or worm_coord[head]['y'] == -1 or worm_coord[head]['y'] == CELL_HEIGHT:
        worm_coord = [{'x': start_x, 'y': start_y},
                      {'x': start_x - 1, 'y': start_y},
                      {'x': start_x - 2, 'y': start_y}]
        direction = RIGHT
        screen_image = pg.surfarray.array3d(pg.display.get_surface())
        reward = -1
        return reward, screen_image
    # 检测贪吃蛇有没有碰到自己
    for worm_body in worm_coord[1:]:
        if worm_body['x'] == worm_coord[head]['x'] and worm_body['y'] == worm_coord[head]['y']:
            worm_coord = [{'x': start_x, 'y': start_y},
                          {'x': start_x - 1, 'y': start_y},
                          {'x': start_x - 2, 'y': start_y}]
            direction = RIGHT
            screen_image = pg.surfarray.array3d(pg.display.get_surface())
            reward = -1
            return reward, screen_image
    # 检测贪吃蛇有没有吃到苹果
    # 如果吃到苹果,不用删除末尾,相当于增加一节
    if worm_coord[head]['x'] == apple['x'] and worm_coord[head]['y'] == apple['y']:
        reward = 1
        apple = get_random_location(worm_coord)
    # 如果没有吃到苹果,删除末尾一节
    else:
        del worm_coord[-1]
    # 贪吃蛇移动逻辑
    # 如果贪吃蛇出现180度旋转,则方向和原来方向保持不变
    if not examine_direction(pre_direction):
        direction = pre_direction
    # 根据贪吃蛇方向确定新头部的位置
    new_head = {}
    if direction == UP:
        new_head = {'x': worm_coord[head]['x'], 'y': worm_coord[head]['y'] - 1}
    elif direction == DOWN:
        new_head = {'x': worm_coord[head]['x'], 'y': worm_coord[head]['y'] + 1}
    elif direction == LEFT:
        new_head = {'x': worm_coord[head]['x'] - 1, 'y': worm_coord[head]['y']}
    elif direction == RIGHT:
        new_head = {'x': worm_coord[head]['x'] + 1, 'y': worm_coord[head]['y']}
    worm_coord.insert(0, new_head)
    screen.fill(BLACK)
    draw_grid()
    draw_apple()
    draw_worm_coord()
    pg.display.update()
    clock.tick(FPS)
    screen_image = pg.surfarray.array3d(pg.display.get_surface())
    return reward, screen_image


def run():
    global screen, clock
    pg.init()
    screen = pg.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
    clock = pg.time.Clock()
    show_start_screen()
    # while True:
    #     run_game()
    #     clock.tick(FPS)
    #     show_game_over_screen(screen)


start_x, start_y = 5, 5
head = 0
worm_coord = [{'x': start_x, 'y': start_y},
              {'x': start_x - 1, 'y': start_y},
              {'x': start_x - 2, 'y': start_y}]
direction = RIGHT
apple = get_random_location(worm_coord)
# run()

# 训练参数
LEARNING_RATE = 0.99
INITIAL_EPSILON = 1.0
FINAL_EPSILON = 0.05
EXPLORE = 50000
OBSERVE = 100
REPLAY_MEMORY = 1024
BATCH = 14

tf.compat.v1.disable_eager_execution()
input_image = tf.compat.v1.placeholder("float", [None, 160, 120, 4])
action = tf.compat.v1.placeholder("float", [None, 4])


def convolutional_neural_network(input_image):
    weights = {"w_conv1": tf.Variable(tf.zeros([8, 8, 4, 32])),
               "w_conv2": tf.Variable(tf.zeros([4, 4, 32, 64])),
               "w_conv3": tf.Variable(tf.zeros([3, 3, 64, 64])),
               "w_fc4": tf.Variable(tf.zeros([128, 64])),
               "w_out": tf.Variable(tf.zeros([64, 4]))}
    bias = {"b_conv1": tf.Variable(tf.zeros([32])),
            "b_conv2": tf.Variable(tf.zeros([64])),
            "b_conv3": tf.Variable(tf.zeros([64])),
            "b_fc4": tf.Variable(tf.zeros([64])),
            "b_out": tf.Variable(tf.zeros([4]))}
    conv1 = tf.nn.relu(tf.nn.conv2d(input_image, weights["w_conv1"], strides=[1, 4, 4, 1], padding="VALID")
                       + bias["b_conv1"])
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, weights["w_conv2"], strides=[1, 2, 2, 1], padding="VALID")
                       + bias["b_conv2"])
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    conv3 = tf.nn.relu(tf.nn.conv2d(conv2, weights["w_conv3"], strides=[1, 1, 1, 1], padding="VALID")
                       + bias["b_conv3"])
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    conv3_flat = tf.reshape(conv3, [-1, 128])
    fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights["w_fc4"]) + bias["b_fc4"])
    out = tf.matmul(fc4, weights["w_out"] + bias["b_out"])
    return out


def train(input_image):
    tf.compat.v1.disable_eager_execution()
    predict_action = convolutional_neural_network(input_image)
    argmax = tf.compat.v1.placeholder("float", [None, 4])
    gt = tf.compat.v1.placeholder("float", [None])
    # 定义均方差损失函数的计算过程
    action = tf.reduce_sum(tf.multiply(predict_action, argmax))
    cost = tf.reduce_mean(tf.square(action - gt))
    # 定义机器学习过程
    optimizer = tf.compat.v1.train.AdamOptimizer(1e-2).minimize(cost)
    run()
    D = deque()
    _, image = run_game()
    image = cv.cvtColor(cv.resize(image, (120, 160)), cv.COLOR_BGR2GRAY)
    ret, image = cv.threshold(image, 1, 255, cv.THRESH_BINARY)
    input_image_data = np.stack((image, image, image, image), axis=2)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.initialize_all_variables())
        # saver = tf.train.Saver()
        n = 0
        epsilon = INITIAL_EPSILON
        while True:
            action_t = predict_action.eval(feed_dict={input_image: [input_image_data]})[0]
            argmax_t = np.zeros([4], dtype=np.int)
            # 每个状态以epsilon的概率进行探索
            if random.random() <= epsilon:
                max_index = random.randrange(4)
            else:
                max_index = np.argmax(action_t)
            argmax_t[max_index] = 1
            if epsilon > FINAL_EPSILON:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
            reward, image = run_game(list(argmax_t))
            image = cv.cvtColor(cv.resize(image, (120, 160)), cv.COLOR_BGR2GRAY)
            ret, image = cv.threshold(image, 1, 255, cv.THRESH_BINARY)
            image = np.reshape(image, (160, 120, 1))
            input_image_data1 = np.append(image, input_image_data[:, :, 0: 3], axis=2)
            D.append((input_image_data, argmax_t, reward, input_image_data1))
            if len(D) > REPLAY_MEMORY:
                D.popleft()
            if n > OBSERVE:
                min_batch = random.sample(D, BATCH)
                input_image_data_batch = [d[0] for d in min_batch]
                argmax_batch = [d[1] for d in min_batch]
                reward_batch = [d[2] for d in min_batch]
                input_image_data1_batch = [d[3] for d in min_batch]
                gt_batch = []
                out_batch = predict_action.eval(feed_dict={input_image: input_image_data1_batch})
                for i in range(0, len(min_batch)):
                    gt_batch.append(reward_batch[i] + LEARNING_RATE * np.max(out_batch[i]))
                # 利用梯度反向传播更新模型参数
                optimizer.run(feed_dict={gt: gt_batch, argmax: argmax_batch, input_image: input_image_data_batch})
            input_image_data = input_image_data1
            n = n + 1
            print(n, "epsilon:", epsilon, " ", "action:", max_index, " ", "reward:", reward)


train(input_image)

测试结果

DQN:深度Q-网络_第4张图片
测试中,贪吃蛇进行了50000轮训练,每一次训练,贪吃蛇通过策略函数选取合适的动作,并将结果存储到经验池中,也就是上述代码中的双向队列Q。贪吃蛇已基本具备躲避边缘和寻找最优路径吃苹果的能力。

你可能感兴趣的:(强化学习,python)