在前面的博客中,我们使用DQN等算法训练了agent,并得到了较高的分数。DQN中的神经网络输出的是每个动作的Q值,然后选择Q值最大的那个动作执行。那我们为什么不直接让神经网络输出动作(的概率),一步到位呢?Policy Gradient(策略梯度)就可以做到这一点。
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl
import numpy as np
import gym
from parl.utils import logger
from paddle.distribution import Categorical
# Learning rate handed to PolicyGradient, which passes it on to its Adam optimizer.
LEARNING_RATE = 0.001
class Model(parl.Model):
    """Policy network: maps an observation to a probability for each action.

    Layout: Linear -> tanh -> Linear -> softmax. The hidden layer is ten
    times as wide as the number of actions.

    Args:
        obs_dim: length of the observation vector.
        act_dim: number of discrete actions.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        hidden_size = 10 * act_dim
        self.fc1 = nn.Linear(obs_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, act_dim)

    def forward(self, obs):
        """Return action probabilities (softmax with Paddle's default axis, the last one)."""
        hidden = F.tanh(self.fc1(obs))
        logits = self.fc2(hidden)
        return F.softmax(logits)
class PolicyGradient(parl.Algorithm):
    """REINFORCE-style policy-gradient algorithm.

    Args:
        model: policy network whose output is used as action probabilities.
        lr: learning rate for the Adam optimizer; must be a float (the
            default ``None`` fails the check, so callers must pass it).
    """

    def __init__(self, model, lr=None):
        self.model = model
        assert isinstance(lr, float)
        self.optimizer = paddle.optimizer.Adam(
            learning_rate=lr,
            parameters=model.parameters(),
        )

    def predict(self, obs):
        """Return the model output (action probabilities) for ``obs``."""
        return self.model(obs)

    def learn(self, obs, act, reward):
        """Take one optimizer step on ``mean(-log_prob(act) * reward)``.

        Args:
            obs: batch of observations.
            act: indices of the actions that were taken.
            reward: per-step weights for the log-probabilities (the script
                passes reward-to-go values).

        Returns:
            The loss tensor, computed before the parameter update.
        """
        action_probs = self.model(obs)
        # NOTE(review): Categorical is built from the softmax output, not from
        # raw logits; confirm the installed Paddle version's log_prob treats its
        # input as (unnormalized) probabilities.
        log_prob = Categorical(action_probs).log_prob(act)
        weighted = log_prob * reward
        loss = paddle.mean(-1 * weighted)

        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()
        return loss
class Agent(parl.Agent):
    """Agent that turns policy probabilities into actions and trains the policy.

    Args:
        algorithm: algorithm exposing ``predict(obs)`` (action probabilities)
            and ``learn(obs, act, reward)`` (returns a loss tensor).
        obs_dim: observation size.
        act_dim: number of discrete actions.
    """

    def __init__(self, algorithm, obs_dim, act_dim):
        super().__init__(algorithm)
        self.obs_dim = obs_dim
        self.act_dim = act_dim

    def _action_probs(self, obs):
        """Return the policy's probabilities for one observation as a flat numpy array."""
        obs = paddle.to_tensor(obs, dtype='float32')
        act_prob = self.alg.predict(obs)
        # Convert to numpy explicitly and flatten, so a single observation given
        # as shape (obs_dim,) or (1, obs_dim) yields a vector of length act_dim.
        return act_prob.numpy().reshape(-1)

    def sample(self, obs):
        """Sample an action from the policy distribution (used while training)."""
        act_prob = self._action_probs(obs)
        return np.random.choice(range(self.act_dim), p=act_prob)

    def predict(self, obs):
        """Return the index of the most probable action (used for evaluation)."""
        act_prob = self._action_probs(obs)
        return int(np.argmax(act_prob))

    def learn(self, obs, act, reward):
        """Convert one episode's data to tensors and run a single update.

        Args:
            obs: array of observations, one row per step.
            act: 1-D array of taken actions, one per step.
            reward: 1-D array of per-step weights (reward-to-go), one per step.

        Returns:
            The scalar loss as a numpy value.
        """
        # Give actions and rewards a trailing axis so they line up with the
        # (steps, act_dim) probability output when selecting log-probs.
        act = np.expand_dims(act, axis=-1)
        reward = np.expand_dims(reward, axis=-1)
        obs = paddle.to_tensor(obs, dtype='float32')
        act = paddle.to_tensor(act, dtype='int32')
        reward = paddle.to_tensor(reward, dtype='float32')
        loss = self.alg.learn(obs, act, reward)
        # paddle.mean yields a 0-D tensor in Paddle >= 2.5 (shape [1] before), so
        # indexing ``loss.numpy()[0]`` raises IndexError there; flatten first.
        return loss.numpy().reshape(-1)[0]
def run_episode(env, agent):
    """Roll out one training episode, sampling every action from ``agent``.

    Returns:
        A tuple of three parallel lists ``(obs_list, action_list, reward_list)``:
        the observation each action was chosen from, the sampled action, and
        the reward ``env.step`` returned for it. The observation produced by
        the final (terminal) step is not recorded.
    """
    observations, actions, rewards = [], [], []
    obs = env.reset()
    done = False
    while not done:
        observations.append(obs)
        action = agent.sample(obs)  # stochastic action, so the policy keeps exploring
        actions.append(action)
        obs, reward, done, _ = env.step(action)
        rewards.append(reward)
    return observations, actions, rewards
def evaluate(env, agent, render=False, n_episodes=5):
    """Evaluate ``agent`` and return the mean total reward over several episodes.

    Actions come from ``agent.predict`` rather than ``agent.sample``; for the
    Agent in this script that is the most probable action.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step()``.
        agent: object with a ``predict(obs)`` method.
        render: if True, call ``env.render()`` after every step.
        n_episodes: number of episodes to average over (default 5).

    Returns:
        The mean of the per-episode (undiscounted) reward sums.

    Raises:
        ValueError: if ``n_episodes`` is less than 1 (the mean would be undefined).
    """
    if n_episodes < 1:
        raise ValueError('n_episodes must be >= 1, got {}'.format(n_episodes))
    eval_reward = []
    for _ in range(n_episodes):
        obs = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.predict(obs)
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)
def calc_reward_to_go(reward_list, gamma=1.0):
    """Compute the return G_t for every step of one episode.

    Uses the backward recursion G_t = r_t + gamma * G_{t+1}, with the last
    step's return equal to its own reward.

    Args:
        reward_list: per-step rewards of one episode (list or 1-D array).
            It is NOT modified.
        gamma: discount factor; 1.0 means undiscounted.

    Returns:
        numpy array of the same length holding G_t for each step.
    """
    # Work on a copy: the original summed in place, mutating the caller's data
    # (and truncating to int when given an integer numpy array with gamma < 1).
    returns = list(reward_list)
    for i in range(len(returns) - 2, -1, -1):
        returns[i] += gamma * returns[i + 1]
    return np.array(returns)
# Create the CartPole environment and read its dimensions.
env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n
logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

# Assemble the PARL stack: Model -> Algorithm -> Agent.
model = Model(obs_dim, act_dim)
alg = PolicyGradient(model, lr=LEARNING_RATE)
agent = Agent(alg, obs_dim=obs_dim, act_dim=act_dim)

# To inspect a previously saved model instead of training, restore it with
# agent.restore('./model.ckpt') and then call evaluate(env, agent, render=True).

# One on-policy update per episode: roll out, compute G_t, learn.
for i in range(1000):
    obs_list, action_list, reward_list = run_episode(env, agent)
    if i % 10 == 0:
        logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))

    batch_obs, batch_action = np.array(obs_list), np.array(action_list)
    batch_reward = calc_reward_to_go(reward_list)
    agent.learn(batch_obs, batch_action, batch_reward)

    if (i + 1) % 100 == 0:
        # render=True shows the episodes, but only when running locally
        # (AI Studio cannot display the window).
        total_reward = evaluate(env, agent, render=False)
        logger.info('Test reward: {}'.format(total_reward))

# Save the trained parameters to ./model.ckpt.
agent.save('./model.ckpt')
输出:策略梯度算法收敛得特别快,训练几十秒就基本收敛了。下面是agent的训练日志:
[12-02 10:05:40 MainThread @3974053785.py:128] obs_dim 4, act_dim 2
[12-02 10:05:40 MainThread @machine_info.py:88] nvidia-smi -L found gpu count: 1
[12-02 10:05:40 MainThread @3974053785.py:145] Episode 0, Reward Sum 34.0.
[12-02 10:05:40 MainThread @3974053785.py:145] Episode 10, Reward Sum 12.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 20, Reward Sum 11.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 30, Reward Sum 18.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 40, Reward Sum 20.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 50, Reward Sum 31.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 60, Reward Sum 40.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 70, Reward Sum 16.0.
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 80, Reward Sum 16.0.
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 90, Reward Sum 22.0.
[12-02 10:05:42 MainThread @3974053785.py:154] Test reward: 45.2
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 100, Reward Sum 13.0.
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 110, Reward Sum 14.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 120, Reward Sum 68.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 130, Reward Sum 28.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 140, Reward Sum 25.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 150, Reward Sum 55.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 160, Reward Sum 87.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 170, Reward Sum 35.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 180, Reward Sum 59.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 190, Reward Sum 40.0.
[12-02 10:05:45 MainThread @3974053785.py:154] Test reward: 81.0
[12-02 10:05:45 MainThread @3974053785.py:145] Episode 200, Reward Sum 63.0.
[12-02 10:05:45 MainThread @3974053785.py:145] Episode 210, Reward Sum 22.0.
[12-02 10:05:45 MainThread @3974053785.py:145] Episode 220, Reward Sum 86.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 230, Reward Sum 65.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 240, Reward Sum 24.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 250, Reward Sum 26.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 260, Reward Sum 34.0.
[12-02 10:05:47 MainThread @3974053785.py:145] Episode 270, Reward Sum 70.0.
[12-02 10:05:47 MainThread @3974053785.py:145] Episode 280, Reward Sum 37.0.
[12-02 10:05:47 MainThread @3974053785.py:145] Episode 290, Reward Sum 43.0.
[12-02 10:05:48 MainThread @3974053785.py:154] Test reward: 98.0
[12-02 10:05:48 MainThread @3974053785.py:145] Episode 300, Reward Sum 33.0.
[12-02 10:05:48 MainThread @3974053785.py:145] Episode 310, Reward Sum 49.0.
[12-02 10:05:49 MainThread @3974053785.py:145] Episode 320, Reward Sum 66.0.
[12-02 10:05:49 MainThread @3974053785.py:145] Episode 330, Reward Sum 54.0.
[12-02 10:05:49 MainThread @3974053785.py:145] Episode 340, Reward Sum 98.0.
[12-02 10:05:50 MainThread @3974053785.py:145] Episode 350, Reward Sum 81.0.
[12-02 10:05:50 MainThread @3974053785.py:145] Episode 360, Reward Sum 78.0.
[12-02 10:05:50 MainThread @3974053785.py:145] Episode 370, Reward Sum 112.0.
[12-02 10:05:51 MainThread @3974053785.py:145] Episode 380, Reward Sum 40.0.
[12-02 10:05:51 MainThread @3974053785.py:145] Episode 390, Reward Sum 37.0.
[12-02 10:05:52 MainThread @3974053785.py:154] Test reward: 104.4
[12-02 10:05:52 MainThread @3974053785.py:145] Episode 400, Reward Sum 74.0.
[12-02 10:05:52 MainThread @3974053785.py:145] Episode 410, Reward Sum 44.0.
[12-02 10:05:52 MainThread @3974053785.py:145] Episode 420, Reward Sum 39.0.
[12-02 10:05:53 MainThread @3974053785.py:145] Episode 430, Reward Sum 104.0.
[12-02 10:05:53 MainThread @3974053785.py:145] Episode 440, Reward Sum 22.0.
[12-02 10:05:54 MainThread @3974053785.py:145] Episode 450, Reward Sum 92.0.
[12-02 10:05:54 MainThread @3974053785.py:145] Episode 460, Reward Sum 16.0.
[12-02 10:05:54 MainThread @3974053785.py:145] Episode 470, Reward Sum 16.0.
[12-02 10:05:55 MainThread @3974053785.py:145] Episode 480, Reward Sum 117.0.
[12-02 10:05:55 MainThread @3974053785.py:145] Episode 490, Reward Sum 72.0.
[12-02 10:05:56 MainThread @3974053785.py:154] Test reward: 198.8
[12-02 10:05:56 MainThread @3974053785.py:145] Episode 500, Reward Sum 20.0.
[12-02 10:05:57 MainThread @3974053785.py:145] Episode 510, Reward Sum 62.0.
[12-02 10:05:57 MainThread @3974053785.py:145] Episode 520, Reward Sum 17.0.
[12-02 10:05:58 MainThread @3974053785.py:145] Episode 530, Reward Sum 56.0.
[12-02 10:05:59 MainThread @3974053785.py:145] Episode 540, Reward Sum 101.0.
[12-02 10:05:59 MainThread @3974053785.py:145] Episode 550, Reward Sum 178.0.
[12-02 10:06:00 MainThread @3974053785.py:145] Episode 560, Reward Sum 57.0.
[12-02 10:06:01 MainThread @3974053785.py:145] Episode 570, Reward Sum 158.0.
[12-02 10:06:01 MainThread @3974053785.py:145] Episode 580, Reward Sum 72.0.
[12-02 10:06:02 MainThread @3974053785.py:145] Episode 590, Reward Sum 161.0.
[12-02 10:06:03 MainThread @3974053785.py:154] Test reward: 182.2
[12-02 10:06:03 MainThread @3974053785.py:145] Episode 600, Reward Sum 113.0.
[12-02 10:06:04 MainThread @3974053785.py:145] Episode 610, Reward Sum 112.0.
[12-02 10:06:04 MainThread @3974053785.py:145] Episode 620, Reward Sum 61.0.
[12-02 10:06:05 MainThread @3974053785.py:145] Episode 630, Reward Sum 143.0.
[12-02 10:06:06 MainThread @3974053785.py:145] Episode 640, Reward Sum 156.0.
[12-02 10:06:07 MainThread @3974053785.py:145] Episode 650, Reward Sum 150.0.
[12-02 10:06:08 MainThread @3974053785.py:145] Episode 660, Reward Sum 167.0.
[12-02 10:06:09 MainThread @3974053785.py:145] Episode 670, Reward Sum 200.0.
[12-02 10:06:10 MainThread @3974053785.py:145] Episode 680, Reward Sum 200.0.
[12-02 10:06:10 MainThread @3974053785.py:145] Episode 690, Reward Sum 164.0.
[12-02 10:06:12 MainThread @3974053785.py:154] Test reward: 199.8
[12-02 10:06:12 MainThread @3974053785.py:145] Episode 700, Reward Sum 126.0.
[12-02 10:06:13 MainThread @3974053785.py:145] Episode 710, Reward Sum 164.0.
[12-02 10:06:13 MainThread @3974053785.py:145] Episode 720, Reward Sum 200.0.
[12-02 10:06:14 MainThread @3974053785.py:145] Episode 730, Reward Sum 92.0.
[12-02 10:06:15 MainThread @3974053785.py:145] Episode 740, Reward Sum 200.0.
[12-02 10:06:16 MainThread @3974053785.py:145] Episode 750, Reward Sum 197.0.
[12-02 10:06:17 MainThread @3974053785.py:145] Episode 760, Reward Sum 200.0.
[12-02 10:06:18 MainThread @3974053785.py:145] Episode 770, Reward Sum 178.0.
[12-02 10:06:19 MainThread @3974053785.py:145] Episode 780, Reward Sum 200.0.
[12-02 10:06:20 MainThread @3974053785.py:145] Episode 790, Reward Sum 200.0.
[12-02 10:06:21 MainThread @3974053785.py:154] Test reward: 200.0
[12-02 10:06:21 MainThread @3974053785.py:145] Episode 800, Reward Sum 144.0.
[12-02 10:06:22 MainThread @3974053785.py:145] Episode 810, Reward Sum 195.0.
[12-02 10:06:24 MainThread @3974053785.py:145] Episode 820, Reward Sum 174.0.
[12-02 10:06:25 MainThread @3974053785.py:145] Episode 830, Reward Sum 167.0.
[12-02 10:06:26 MainThread @3974053785.py:145] Episode 840, Reward Sum 125.0.
[12-02 10:06:27 MainThread @3974053785.py:145] Episode 850, Reward Sum 62.0.
[12-02 10:06:27 MainThread @3974053785.py:145] Episode 860, Reward Sum 200.0.
[12-02 10:06:29 MainThread @3974053785.py:145] Episode 870, Reward Sum 137.0.
[12-02 10:06:30 MainThread @3974053785.py:145] Episode 880, Reward Sum 200.0.
[12-02 10:06:31 MainThread @3974053785.py:145] Episode 890, Reward Sum 30.0.
[12-02 10:06:32 MainThread @3974053785.py:154] Test reward: 200.0
[12-02 10:06:32 MainThread @3974053785.py:145] Episode 900, Reward Sum 161.0.
[12-02 10:06:33 MainThread @3974053785.py:145] Episode 910, Reward Sum 200.0.
[12-02 10:06:34 MainThread @3974053785.py:145] Episode 920, Reward Sum 194.0.
[12-02 10:06:36 MainThread @3974053785.py:145] Episode 930, Reward Sum 200.0.
[12-02 10:06:37 MainThread @3974053785.py:145] Episode 940, Reward Sum 200.0.
[12-02 10:06:38 MainThread @3974053785.py:145] Episode 950, Reward Sum 200.0.
[12-02 10:06:39 MainThread @3974053785.py:145] Episode 960, Reward Sum 200.0.
[12-02 10:06:40 MainThread @3974053785.py:145] Episode 970, Reward Sum 200.0.
[12-02 10:06:41 MainThread @3974053785.py:145] Episode 980, Reward Sum 200.0.
[12-02 10:06:42 MainThread @3974053785.py:145] Episode 990, Reward Sum 193.0.
[12-02 10:06:44 MainThread @3974053785.py:154] Test reward: 200.0