利用DQN解决Gym库的CartPole问题

刚刚入门强化学习,有问题还希望多多交流~

CartPole环境介绍

关于Gym库的CartPole环境请参考大佬的博客CartPole环境介绍

DQN介绍

DQN相比于Q_Learning其实就是将Q表变成了神经网络,也就是我们在t时刻向神经网络中输入一个状态S,然后神经网络会对这个状态下所能采取的动作进行评分并通过贪婪策略选择动作A得到Q(S,A),我们回想一下Q_Learning算法,看一下Q(S,A)的更新公式在这里插入图片描述

他是需要下一个状态S’的maxQ(S’,A’)来完成当前Q(S,A)的更新,因此需要将S’输入神经网络,并让神经网络完成对A’的评分,并选择最大的Q(S’,A’)。得到maxQ(S’,A’)后,即可以完成对Q(S,A)的更新。

DQN更新方式

DQN更新和Q_learning不同,不是用上面的公式,而是使用的神经网络中设置损失函数完成更新,为了好解释我们先设置一下名字
利用DQN解决Gym库的CartPole问题_第1张图片

因为我们最终的目的是使Q(S,A)更新到最优,如果Qtarget和Qvalue之间的差距很小很小甚至为0,那么我们更新Q(S,A)的目标也就做到了,因此这里DQN更新方式采用了深度学习中损失函数的思想来更新

代码部分

整个项目的代码请见github代码
在开始之前要看一下倒立摆的状态空间和动作空间

import gym
env = gym.make('CartPole-v0')
observation = env.reset()
print(observation)#[-0.00478028 -0.02917182  0.00313288  0.03160127]状态空间为4
print(env.action_space)##Discrete(2)##动作是两个离散的动作左移(0)和右移(1)

引入必要的包

import gym
import random
import numpy as np
import tensorflow as tf
import tensorlayer as tl
import argparse
import os
import time
import matplotlib.pyplot as plt

创建网络

def crateModel(input_state):
    input_layer=tl.layers.Input(input_state)
    layer1=tl.layers.Dense(32, act=None, W_init=tf.random_uniform_initializer(0, 0.01), b_init=None)(input_layer)
    layer2 = tl.layers.Dense(16, act=None, W_init=tf.random_uniform_initializer(0, 0.01), b_init=None)(layer1)
    outputlayer = tl.layers.Dense(2, act=None, W_init=tf.random_uniform_initializer(0, 0.01), b_init=None)(layer2)
    return tl.models.Model(inputs=input_layer,outputs=outputlayer)

保存模型

def save_ckpt(model):
    tl.files.save_npz(model.trainable_weights, name='dqn_model.npz')

加载模型

def load_ckpt(model):
    tl.files.load_and_assign_npz(name="dqn_model.npz",network=model)

初始化环境并开始训练

if __name__=='__main__':

    QNetwork=crateModel([None, 4])
    QNetwork.train()
    train_weight=QNetwork.trainable_weights
    optimizer=tf.optimizers.SGD(args.lr)
    env = gym.make('CartPole-v1')
    if args.train:
        t0=time.time()
        all_episode_reward=[]
        for i in range(args.train_episodes):
            env.render()
            total_reward,done=0,False
            S = env.reset()
            while not done:
                Q=QNetwork(np.asarray([S], dtype=np.float32)).numpy()
                A=np.argmax(Q,1)
                if np.random.rand(1) < args.eps:
                    A[0] = env.action_space.sample()
                S_,reward,done,_=env.step(A[0])
                Q_=QNetwork(np.asarray([S_], dtype=np.float32)).numpy()
                maxQ_=np.max(Q_)
                targetQ=Q
                targetQ[0,A[0]]=reward+0.9*maxQ_
                with tf.GradientTape() as tape:
                    q_values=QNetwork(np.asarray([S], dtype=np.float32))
                    _loss=tl.cost.mean_squared_error(targetQ, q_values, is_mean=False)
                    grad = tape.gradient(_loss, train_weight)
                optimizer.apply_gradients(zip(grad, train_weight))
                total_reward+=reward
                S=S_
                if done==True:
                    args.eps=1./((i / 50) + 10)
                    break
            print('Training  | Episode: {}/{}  | Episode Reward: {:.4f} | Running Time: {:.4f}' \
                  .format(i, args.train_episodes, total_reward, time.time() - t0))
            if i==0:
                all_episode_reward.append(total_reward)
            else:
                all_episode_reward.append(all_episode_reward[-1] * 0.9 + total_reward * 0.1)
        save_ckpt(QNetwork)

训练完成后对模型进行测试

import time
import gym
import numpy as np
env = gym.make('CartPole-v1')
from gym_demo import load_ckpt,crateModel
t0 = time.time()
QNetwork_test=crateModel([None, 4])
load_ckpt(QNetwork_test)
num_episodes=300
QNetwork_test.eval()
for i in range(num_episodes):
    episode_time = time.time()
    S=env.reset()
    total_reword,done=0,False
    env.render()
    step_counter=0
    while not done:
        step_counter += 1
        Q=QNetwork_test(np.asarray([S],np.float32)).numpy()
        A=np.argmax(Q,1)
        S_,reward,done,_=env.step(A[0])
        total_reword+=reward
        S=S_
        if done==True:
            print(step_counter)
            break

    print('Episode: {}/{}  | Episode Reward: {:.4f} | Running Time: {:.4f}' \
          .format(i, num_episodes, total_reword, time.time() - t0))

你可能感兴趣的:(强化学习,强化学习)