Dueling DQN for CartPole-v0: a three-network implementation
dueling
The single Q-value output is split into two outputs: a state-value output and an advantage ("influence factor") output.
For example, in some states the game will end no matter which action is taken; there the value of the state matters far more than the choice of action.
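In code, the two streams are recombined into Q values by subtracting the mean advantage. A minimal numpy sketch of the same aggregation that get_net uses below:

import numpy as np

v = np.array([[1.0]])            # state value V(s), one scalar per state
a = np.array([[0.5, -0.5]])      # advantages A(s, a), one value per action
# Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
q = v + (a - a.mean(axis=1, keepdims=True))
print(q)                         # [[ 1.5  0.5]]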
[Figure: training loss curve]
[Figure: test score curve]
A bad run:
[Figure: training loss curve]
[Figure: test score curve]
model
Given an input state, the network returns the value of each action.
import tensorflow as tf
import tensorflow.contrib.slim as slim


def get_net(in_x, out_num, scope, hide_num=128):
    with tf.variable_scope(scope):
        with slim.arg_scope(
            [slim.fully_connected],
            activation_fn=tf.nn.leaky_relu,
        ):
            net = slim.fully_connected(in_x, hide_num)
            net = slim.fully_connected(net, hide_num)
            net = slim.fully_connected(net, hide_num)
            # Dueling heads: a scalar state value V(s) and one advantage A(s, a) per action.
            # Both heads use a linear output so the Q values are not squashed by the activation.
            v = slim.fully_connected(net, 1, activation_fn=None)
            a = slim.fully_connected(net, out_num, activation_fn=None)
            # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); V broadcasts over the action axis.
            net = v + (a - tf.reduce_mean(a, axis=1, keep_dims=True))
    return net
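A quick shape check for the dueling network. This is a minimal sketch, assuming the function above is saved as model.py (which is how the training script below imports it); the 'demo_net' scope name is arbitrary:

import tensorflow as tf
from model import get_net

s = tf.placeholder(tf.float32, (None, 4))   # four CartPole features
q = get_net(s, 2, 'demo_net')               # two actions: push left / push right
print(q.shape)                              # (?, 2): one Q value per action

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(q, {s: [[0., 0., 0., 0.]]}))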
env wrapper
Wraps the environment so the training code can use it conveniently.
import gym
import math


# Observation:
#     Type: Box(4)
#     Num   Observation              Min     Max
#     0     Cart Position            -4.8    4.8
#     1     Cart Velocity            -Inf    Inf
#     2     Pole Angle               -24°    24°
#     3     Pole Velocity At Tip     -Inf    Inf
#
# Actions:
#     Type: Discrete(2)
#     Num   Action
#     0     Push cart to the left
#     1     Push cart to the right
#
class Env:
    frame_num = 4

    def __init__(self):
        self.env = gym.make('CartPole-v0')

    # Only report whether the pole fell or the cart left the track;
    # the 200-step limit is handled by the caller.
    def step(self, action):
        observation_new, reward, done, info = self.env.step(action)
        theta_threshold_radians = 12 * 2 * math.pi / 360
        x_threshold = 2.4
        x, x_dot, theta, theta_dot = observation_new
        done = any([
            x < -x_threshold,
            x > x_threshold,
            theta < -theta_threshold_radians,
            theta > theta_threshold_radians
        ])
        if done:
            reward = -2
        return observation_new, reward, done, info

    def reset(self):
        obs = self.env.reset()
        return obs

    def render(self):
        self.env.render()


def main():
    env = Env()
    obs = env.reset()
    print(obs)


if __name__ == '__main__':
    main()
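A short random-rollout check of the wrapper. This is a minimal sketch, assuming the class above is saved as CartPole.py (which is how the training script imports it):

import numpy as np
from CartPole import Env

env = Env()
obs = env.reset()
for t in range(200):
    action = np.random.randint(0, 2)           # random action: push left or right
    obs, reward, done, info = env.step(action)
    if done:                                   # pole fell or cart out of bounds
        print('episode ended after', t + 1, 'steps, final reward', reward)
        break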
train
A background thread keeps refreshing the replay memory (the layout of one memory row is sketched below the list); training uses three networks:
refresh network: chooses actions while the memory is being refreshed
evaluation network: used during training to supply the target Q values
training network: the network that is actually trained; its parameters are copied to the refresh and evaluation networks at regular intervals
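Each row pushed into the replay memory flattens one transition into n_features + 2 + n_features = 10 numbers. A minimal numpy sketch of that column layout and of how learn() slices a batch (the values here are purely illustrative):

import numpy as np

n_features = 4
# one row: [s_old (4 values), action, reward, s_new (4 values)] -> 10 columns
row = np.concatenate([np.zeros(n_features), [1, -2.0], np.ones(n_features)])
batch = np.stack([row, row])                      # shape (2, 10), a tiny replay batch

s_old  = batch[:, :n_features]                    # columns 0..3: old observations
action = batch[:, n_features].astype(np.int32)    # column 4: action taken
reward = batch[:, n_features + 1]                 # column 5: reward received
s_new  = batch[:, -n_features:]                   # columns 6..9: next observations
print(s_old.shape, action, reward, s_new.shape)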
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
from threading import Thread
import time
from random import sample

from model import get_net
from CartPole import Env
from tensorflow.contrib import slim


class DQN(Thread):
    env = Env()
    n_actions = 2
    n_features = 4
    lr = .01
    update_step = 200
    show_step = 100
    test_num = 5
    train_step = 30000
    max_memory_size = 10000
    min_memory_size = 1000
    batch_size = 32
    reward_decay = .5
    # each row: s_old, action, reward, s_new
    memory = deque(maxlen=max_memory_size)
    # whether the background thread may keep refreshing the memory
    is_refresh = True
    # interval between memory-refresh steps (seconds)
    refresh_inv = .01
    epsilon = .9
    epsilon_decay = .95
    epsilon_decay_step = 500
    min_epsilon = .001

    def __init__(self):
        super(DQN, self).__init__()
        # self.env = Env()
        # Let the memory-refresh thread exit together with the main thread.
        self.daemon = True
        # Three networks: training network, evaluation network, memory-refresh network.
        # State inputs of the three networks.
        self.s_train = tf.placeholder(tf.float32, (None, self.n_features))
        self.s_eval = tf.placeholder(tf.float32, (None, self.n_features))
        self.s_refresh = tf.placeholder(tf.float32, (None, self.n_features))
        # Target Q values for the corresponding states (only q_train_real is used in the loss).
        self.q_train_real = tf.placeholder(tf.float32, (None, self.n_actions))
        self.q_eval_real = tf.placeholder(tf.float32, (None, self.n_actions))
        self.q_refresh_real = tf.placeholder(tf.float32, (None, self.n_actions))
        # Q-value outputs of the three networks.
        self.q_train_eval = get_net(self.s_train, self.n_actions, 'train_net')
        self.q_eval_eval = get_net(self.s_eval, self.n_actions, 'eval_net')
        self.q_refresh_eval = get_net(self.s_refresh, self.n_actions, 'refresh_net')
        # Copy the training network's parameters into the evaluation network.
        self.eval_update_ops = [
            tf.assign(old, new)
            for old, new in zip(slim.get_variables('eval_net'), slim.get_variables('train_net'))
        ]
        # Copy the training network's parameters into the refresh network.
        self.refresh_update_ops = [
            tf.assign(old, new)
            for old, new in zip(slim.get_variables('refresh_net'), slim.get_variables('train_net'))
        ]
        with tf.variable_scope("loss"):
            self.loss = tf.reduce_mean((self.q_train_real - self.q_train_eval) ** 2)
        with tf.variable_scope('train'):
            # self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
            self.train_op = tf.train.AdagradOptimizer(self.lr).minimize(self.loss)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.update()

    # Copy the training network's parameters into the other two networks.
    def update(self):
        self.is_refresh = False
        self.sess.run(self.eval_update_ops)
        self.sess.run(self.refresh_update_ops)
        self.is_refresh = True

    def get_batch(self, batch_size=batch_size):
        # each row: s_old, action, reward, s_new
        batch = sample(self.memory, batch_size)
        batch = np.array(batch)
        # print(batch.shape)  # (32, 10)
        return batch

    # Background thread: keep playing and pushing transitions into the replay memory.
    def run(self):
        env = Env()
        observation = env.reset()
        # Probability of taking a random action; fixed for one complete episode.
        random = np.random.uniform(0, self.epsilon)
        step = 0
        while True:
            time.sleep(self.refresh_inv)
            # Do not refresh while the network parameters are being updated.
            if not self.is_refresh:
                continue
            action = self.choose_action(observation, random)
            observation_new, reward, done, info = env.step(action)
            step += 1
            row = np.concatenate([observation, [action, reward], observation_new])
            self.memory.append(row)
            observation = observation_new
            if done or step >= 200:
                observation = env.reset()
                random = np.random.uniform(0, self.epsilon)
                step = 0

    # By default no random action is taken; Q values come from the evaluation network.
    def choose_action(self, observation, random=0.):
        if np.random.rand() < random:
            return np.random.randint(0, self.n_actions)
        observation = observation[np.newaxis, :]
        action_value = self.sess.run(
            self.q_eval_eval, {
                self.s_eval: observation
            }
        )
        return np.argmax(action_value)

    # Sample a batch from the replay memory and learn from it.
    def learn(self):
        batch_memory = self.get_batch()
        # print('learn ', batch_memory.shape)  # learn (32, 10)
        # Target-side Q values of the next states, from the evaluation network.
        q_old_val = self.sess.run(
            self.q_eval_eval, {
                self.s_eval: batch_memory[:, -self.n_features:]
            }
        )
        # Current Q estimates of the old states, from the training network.
        q_new_val = self.sess.run(
            self.q_train_eval, {
                self.s_train: batch_memory[:, :self.n_features]
            }
        )
        batch_index = np.arange(self.batch_size, dtype=np.int32)
        action_index = batch_memory[:, self.n_features].astype(np.int32)
        reward = batch_memory[:, self.n_features + 1]
        selected_q_next = np.max(q_old_val, axis=1)
        # Replace only the taken action's Q value with the bootstrapped target.
        q_new_val[batch_index, action_index] = reward + self.reward_decay * selected_q_next
        _, loss = self.sess.run(
            [self.train_op, self.loss], {
                self.s_train: batch_memory[:, :self.n_features],
                self.q_train_real: q_new_val
            }
        )
        return loss

    # Play one greedy episode and return how many steps the pole stayed up (max 200).
    def test(self, render=False):
        observation = self.env.reset()
        for i in range(200):
            if render:
                self.env.render()
                time.sleep(.01)
            action = self.choose_action(observation, 0)
            observation, reward, done, info = self.env.step(action)
            if done:
                break
        return i + 1

    # Main training loop.
    def start_train(self):
        while len(self.memory) < self.min_memory_size:
            time.sleep(1)
            print('memory', len(self.memory))
        print('memory refreshed')
        loss_list = []
        score_list = []
        for i in range(1, self.train_step + 1):
            loss = self.learn()
            if not i % self.show_step:
                score = 0
                for j in range(self.test_num):
                    score += self.test()
                score = score / self.test_num
                score_list.append(score)
                loss_list.append(loss)
                print(f'step {i} score {score} loss {loss}')
                # if score == 200:
                #     self.test(True)
            if not i % self.update_step:
                self.update()
                print(i, 'update net')
            if not i % self.epsilon_decay_step:
                self.epsilon = max(self.epsilon * self.epsilon_decay, self.min_epsilon)
                print(i, 'update', self.epsilon)
        return loss_list, score_list


def main():
    dqn = DQN()
    dqn.start()
    loss, score = dqn.start_train()
    plt.plot(range(len(loss)), loss)
    plt.show()
    plt.plot(range(len(score)), score)
    plt.show()


if __name__ == '__main__':
    main()