Reinforcement Learning: the A3C Algorithm
Result:
![A3C training result, figure 1](http://img.e-com-net.com/image/info8/9479702ecd024d30bc0f1f9d248f7214.jpg)
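The full script is listed below. A note on assumptions (mine, not from the original post): the code targets the TensorFlow 2 Keras API and the newer Gym interface in which reset() returns (obs, info) and step() returns a 5-tuple, i.e. gym >= 0.26. A quick sanity check before running it:

# Hypothetical version check for the APIs a3c.py relies on (not part of the original listing).
import gym
import tensorflow as tf
print(tf.__version__)    # the script asserts this starts with '2.'
print(gym.__version__)   # reset(seed=...) -> (obs, info) and the 5-tuple step() require gym >= 0.26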
a3c.py
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['font.size'] = 18
matplotlib.rcParams['figure.titlesize'] = 18
matplotlib.rcParams['figure.figsize'] = [9, 7]
matplotlib.rcParams['font.family'] = ['KaiTi']
matplotlib.rcParams['axes.unicode_minus']=False
plt.figure()
import os
import threading
import gym
import multiprocessing
import numpy as np
from queue import Queue
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Enable memory growth so TensorFlow does not pre-allocate all GPU memory
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
SEED_NUM = 1234
tf.random.set_seed(SEED_NUM)
np.random.seed(SEED_NUM)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
g_mutex = threading.Lock()
class ActorCritic(keras.Model):
    """ Actor-Critic network with a policy head and a value head. """
    def __init__(self, state_size, action_size):
        super(ActorCritic, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        # Policy (actor) head: unnormalized action logits
        self.dense1 = layers.Dense(128, activation='relu')
        self.policy_logits = layers.Dense(action_size)
        # Value (critic) head: scalar state-value estimate
        self.dense2 = layers.Dense(128, activation='relu')
        self.values = layers.Dense(1)

    def call(self, inputs):
        # Actor branch
        x = self.dense1(inputs)
        logits = self.policy_logits(x)
        # Critic branch (fed the raw inputs, not the actor features)
        v = self.dense2(inputs)
        values = self.values(v)
        return logits, values
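# Shape note (an assumed example, not part of the original): calling ActorCritic(4, 2)
# on a batch of two CartPole states, i.e. an input of shape (2, 4), returns policy
# logits of shape (2, 2) and state values of shape (2, 1).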
def record(episode,
           episode_reward,
           worker_idx,
           global_ep_reward,
           result_queue,
           total_loss,
           num_steps):
    """ Logging helper: keeps an exponential moving average of episode rewards. """
    if global_ep_reward == 0:
        global_ep_reward = episode_reward
    else:
        global_ep_reward = global_ep_reward * 0.99 + episode_reward * 0.01
    print(
        f"{episode} | "
        f"Average Reward: {int(global_ep_reward)} | "
        f"Episode Reward: {int(episode_reward)} | "
        f"Loss: {int(total_loss / float(num_steps) * 1000) / 1000} | "
        f"Steps: {num_steps} | "
        f"Worker: {worker_idx}"
    )
    result_queue.put(global_ep_reward)
    return global_ep_reward
class Memory:
    """ Trajectory buffer for one rollout. """
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []

    def store(self, state, action, reward):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
class Agent:
    """ Agent holding the central parameter network (the server). """
    def __init__(self):
        self.opt = optimizers.Adam(1e-3)
        # CartPole: 4-dimensional state, 2 discrete actions
        self.server = ActorCritic(4, 2)
        self.server(tf.random.normal((2, 4)))  # dummy forward pass to build the weights

    def train(self):
        res_queue = Queue()
        # Spawn 10 worker threads, each with its own environment and local network
        workers = [Worker(self.server, self.opt, res_queue, i)
                   for i in range(10)]
        for i, worker in enumerate(workers):
            print("Starting worker {}".format(i))
            worker.start()
        returns = []
        while True:  # collect episode returns until a worker signals completion with None
            reward = res_queue.get()
            if reward is not None:
                returns.append(reward)
            else:
                break
        [w.join() for w in workers]
        print(returns)
        plt.figure()
        plt.plot(np.arange(len(returns)), returns)
        plt.xlabel('Episode')
        plt.ylabel('Total return')
        plt.savefig('a3c-tf-cartpole.svg')
class Worker(threading.Thread):
    def __init__(self, server, opt, result_queue, idx):
        super(Worker, self).__init__()
        self.result_queue = result_queue
        self.server = server  # shared (global) network
        self.opt = opt  # shared optimizer
        self.client = ActorCritic(4, 2)  # local copy of the network
        self.worker_idx = idx
        self.env = gym.make('CartPole-v1').unwrapped
        self.ep_loss = 0.0

    def run(self):
        mem = Memory()
        for epi_counter in range(500):
            current_state, info = self.env.reset(seed=SEED_NUM)
            mem.clear()
            ep_reward = 0.0
            ep_steps = 0
            done = False
            while not done:
                # Sample an action from the local policy
                logits, _ = self.client(tf.constant(current_state[None, :], dtype=tf.float32))
                probs = tf.nn.softmax(logits)
                action = np.random.choice(2, p=probs.numpy()[0])
                new_state, reward, done, truncated, info = self.env.step(action)
                ep_reward += reward
                mem.store(current_state, action, reward)
                ep_steps += 1
                current_state = new_state
                if ep_steps >= 500 or done:
                    # Compute gradients on the local network ...
                    with tf.GradientTape() as tape:
                        total_loss = self.compute_loss(done, new_state, mem)
                    grads = tape.gradient(total_loss, self.client.trainable_weights)
                    # ... apply them to the shared server network ...
                    global g_mutex
                    g_mutex.acquire()
                    self.opt.apply_gradients(zip(grads, self.server.trainable_weights))
                    g_mutex.release()
                    # ... then pull the updated server weights back into the local network
                    g_mutex.acquire()
                    self.client.set_weights(self.server.get_weights())
                    g_mutex.release()
                    mem.clear()
                    self.result_queue.put(ep_reward)
                    print(f"thread worker_idx : {self.worker_idx}, episode reward : {ep_reward}")
                    break
        self.result_queue.put(None)  # signal the main thread that this worker is done
    def compute_loss(self,
                     done,
                     new_state,
                     memory,
                     gamma=0.99):
        # Bootstrap the return with the critic's estimate of the last state's value
        if done:
            reward_sum = 0.
        else:
            reward_sum = self.client(tf.constant(new_state[None, :], dtype=tf.float32))[-1].numpy()[0, 0]
        # Discounted n-step returns, accumulated backwards over the stored rewards
        discounted_rewards = []
        for reward in memory.rewards[::-1]:
            reward_sum = reward + gamma * reward_sum
            discounted_rewards.append(reward_sum)
        discounted_rewards.reverse()
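        # A small worked example of the backward pass above (assumed numbers, not
        # from the original): with rewards [1, 1, 1], bootstrap reward_sum = 0 and
        # gamma = 0.99, the loop produces 1.0, then 1 + 0.99 * 1.0 = 1.99, then
        # 1 + 0.99 * 1.99 = 2.9701, so discounted_rewards = [2.9701, 1.99, 1.0].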
        logits, values = self.client(tf.constant(np.vstack(memory.states), dtype=tf.float32))
        # Advantage = discounted return - critic value estimate, shape [T, 1]
        advantage = tf.constant(np.array(discounted_rewards)[:, None], dtype=tf.float32) - values
        # Critic loss: squared advantage
        value_loss = advantage ** 2
        # Actor loss: policy-gradient term weighted by the (gradient-stopped) advantage
        policy = tf.nn.softmax(logits)
        policy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=memory.actions, logits=logits)
        policy_loss = policy_loss[:, None] * tf.stop_gradient(advantage)
        # Entropy bonus encourages exploration
        entropy = tf.nn.softmax_cross_entropy_with_logits(labels=policy, logits=logits)
        policy_loss = policy_loss - 0.01 * entropy[:, None]
        total_loss = tf.reduce_mean(0.5 * value_loss + policy_loss)
        return total_loss
if __name__ == '__main__':
    master = Agent()
    master.train()
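After master.train() returns, the trained policy lives in master.server. As a minimal sketch (my addition, not part of the original; the helper name evaluate is hypothetical, and it assumes the same gym >= 0.26 API used above), the shared network could be rolled out greedily like this:

def evaluate(agent, episodes=5):
    # Greedy rollout of the trained server network.
    env = gym.make('CartPole-v1')
    for ep in range(episodes):
        state, info = env.reset(seed=SEED_NUM)
        total, done, truncated = 0.0, False, False
        while not (done or truncated):
            logits, _ = agent.server(tf.constant(state[None, :], dtype=tf.float32))
            action = int(tf.argmax(logits, axis=1)[0])  # pick the most probable action
            state, reward, done, truncated, info = env.step(action)
            total += reward
        print(f"eval episode {ep}: return {total}")

Calling evaluate(master) after master.train() prints the greedy return of a few evaluation episodes.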