DQN builds on Q-learning by bringing in a neural network: instead of storing Q-values in a Q-table, it approximates them with the network, which greatly reduces memory usage and also saves time.
On top of Q-learning, DQN adds two new tricks: experience replay and a fixed target Q-network.
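Concretely, for a transition (s, a, r, s', done) sampled from the replay buffer, the code below regresses Q(s, a) onto a target computed with the frozen target network:

    target = r + γ · (1 - done) · max_a' Q_target(s', a')
    loss   = MSE(Q(s, a), target)

Gradients only flow through Q(s, a); the target network is refreshed by periodically copying over the online network's weights.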
import copy
import collections
import random

import gym
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl
from parl.utils import logger

LEARN_FREQ = 5            # learn once every 5 environment steps
MEMORY_SIZE = 20000       # replay buffer capacity
MEMORY_WARMUP_SIZE = 200  # transitions collected before learning starts
BATCH_SIZE = 32           # mini-batch size sampled from the replay buffer
LEARNING_RATE = 0.001     # Adam learning rate
GAMMA = 0.99              # discount factor
class Model(parl.Model):
    """3-layer fully connected Q-network: maps an observation to one Q-value per action."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, act_dim)

    def forward(self, x):
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        Q = self.fc3(h2)
        return Q
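A minimal shape sanity check (a hypothetical snippet, not part of the training script; it assumes CartPole-v0's 4-dimensional observation and 2 discrete actions used later in this post):

# hypothetical check: a batch of one observation should yield one Q-value per action
m = Model(obs_dim=4, act_dim=2)
q = m(paddle.to_tensor(np.zeros((1, 4), dtype='float32')))
print(q.shape)   # [1, 2]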
class Agent(parl.Agent):
    """Interacts with the environment: epsilon-greedy sampling, greedy prediction, and learning."""

    def __init__(self, algorithm, act_dim, e_greed=None, e_greed_decrement=None):
        super().__init__(algorithm)
        self.e_greed = e_greed                       # current exploration probability
        self.e_greed_decrement = e_greed_decrement   # epsilon decay per sampling step
        self.act_dim = act_dim
        self.global_steps = 0
        self.update_target_steps = 200               # sync the target network every 200 learn() calls

    def sample(self, obs):
        """Epsilon-greedy action selection used during training."""
        sample1 = np.random.random()
        if sample1 > self.e_greed:
            act = self.predict(obs)                  # exploit
        else:
            act = np.random.randint(self.act_dim)    # explore
        # decay epsilon, but never below 0.01
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return act

    def predict(self, obs):
        """Greedy action: argmax over the predicted Q-values."""
        obs = paddle.to_tensor(obs, dtype='float32')
        predQ = self.alg.predict(obs)
        act = paddle.argmax(predQ).numpy()[0]
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        """Update the Q-network on one mini-batch; periodically sync the target network."""
        if self.global_steps % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_steps += 1
        act = np.expand_dims(act, axis=-1)
        reward = np.expand_dims(reward, axis=-1)     # expand reward, not act
        terminal = np.expand_dims(terminal, axis=-1)
        obs = paddle.to_tensor(obs, dtype='float32')
        act = paddle.to_tensor(act, dtype='int32')
        reward = paddle.to_tensor(reward, dtype='float32')
        next_obs = paddle.to_tensor(next_obs, dtype='float32')
        terminal = paddle.to_tensor(terminal, dtype='float32')
        loss = self.alg.learn(obs, act, reward, next_obs, terminal)
        return loss.numpy()[0]
class DQN(parl.Algorithm):
    """DQN algorithm: owns the Q-network, its frozen target copy, the optimizer, and the loss."""

    def __init__(self, model, gamma=None, lr=None):
        self.gamma = gamma
        self.model = model
        self.target_model = copy.deepcopy(model)     # fixed target Q-network
        self.optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())
        self.mse_loss = paddle.nn.MSELoss(reduction='mean')

    def predict(self, obs):
        return self.model(obs)

    def learn(self, obs, act, reward, next_obs, terminal):
        # Q(s, a) of the actions actually taken, selected via a one-hot mask
        pred_value = self.model(obs)
        act_dim = pred_value.shape[-1]
        act = paddle.squeeze(act, axis=-1)
        act = F.one_hot(act, num_classes=act_dim)
        pred_value = paddle.multiply(pred_value, act)
        pred_value = paddle.sum(pred_value, axis=1, keepdim=True)
        # TD target computed with the frozen target network (no gradient)
        with paddle.no_grad():
            max_v = self.target_model(next_obs).max(1, keepdim=True)
            target = reward + (1 - terminal) * self.gamma * max_v
        loss = self.mse_loss(pred_value, target)
        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def sync_target(self):
        # copy the online network's weights into the target network
        self.model.sync_weights_to(self.target_model)
class ReplayMemory(object):
    """Fixed-size FIFO buffer of (obs, action, reward, next_obs, done) transitions."""

    def __init__(self, max_size):
        self.buffer1 = collections.deque(maxlen=max_size)

    def append(self, exp):
        self.buffer1.append(exp)

    def sample(self, batch_size):
        # uniformly sample a mini-batch and split it into per-field numpy arrays
        mini_batch = random.sample(self.buffer1, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []
        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)
        return np.array(obs_batch, dtype='float32'), \
               np.array(action_batch, dtype='float32'), \
               np.array(reward_batch, dtype='float32'), \
               np.array(next_obs_batch, dtype='float32'), \
               np.array(done_batch, dtype='float32')

    def __len__(self):
        return len(self.buffer1)
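A quick toy illustration of the buffer's FIFO behaviour (hypothetical, not part of the training script): once more transitions are appended than max_size, the deque silently drops the oldest ones.

tiny = ReplayMemory(max_size=2)           # hypothetical tiny buffer
tiny.append(('s0', 0, 0.0, 's1', False))
tiny.append(('s1', 1, 1.0, 's2', False))
tiny.append(('s2', 0, 1.0, 's3', True))   # evicts the oldest transition
print(len(tiny))                          # 2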
def run_episode(env, agent, rpm):
    """Run one training episode: sample actions, store transitions, learn every LEARN_FREQ steps."""
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)                   # epsilon-greedy action
        next_obs, reward, done, _ = env.step(action)
        rpm.append((obs, action, reward, next_obs, done))
        # learn only after the warm-up phase, and only every LEARN_FREQ steps
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward, batch_next_obs, batch_done) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, batch_done)
        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward
def evaluate(env, agent, render=False):
    """Run 5 greedy episodes (no exploration) and return the mean episode reward."""
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)              # greedy action
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
                # img.set_data(env.render(mode='rgb_array'))
                # display.display(plt.gcf())
                # display.clear_output(wait=True)
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)
env = gym.make('CartPole-v0')
# img = plt.imshow(env.render(mode='rgb_array'))
action_dim = env.action_space.n            # 2 discrete actions for CartPole
obs_shape = env.observation_space.shape    # (4,) for CartPole

rpm = ReplayMemory(MEMORY_SIZE)
model = Model(obs_shape[0], action_dim)
algorithm = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
agent = Agent(
    algorithm,
    act_dim=action_dim,
    e_greed=0.1,              # initial exploration probability
    e_greed_decrement=1e-6    # epsilon decay per sampling step
)
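A quick check on the exploration schedule: e_greed starts at 0.1 and is reduced by 1e-6 on every call to sample, so it takes (0.1 - 0.01) / 1e-6 = 90,000 sampling steps to reach the floor of 0.01, which is consistent with the logs below, where e_greed bottoms out around episode 900.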
# warm up the replay buffer before any learning happens
while len(rpm) < MEMORY_WARMUP_SIZE:
    run_episode(env, agent, rpm)

max_episode = 2000
episode = 0
while episode < max_episode:
    # train for 50 episodes, then evaluate
    for i in range(0, 50):
        total_reward = run_episode(env, agent, rpm)
        episode += 1
    eval_reward = evaluate(env, agent, render=False)
    logger.info('episode{}, e_greed:{}, test_reward:{}'.format(episode, agent.e_greed, eval_reward))

save_path = './dqn_model.ckpt'
agent.save(save_path)
Output:
After training, the agent keeps the pole balanced reliably, and the evaluation reward climbs to CartPole-v0's cap of 200:
[11-29 23:19:36 MainThread @machine_info.py:88] nvidia-smi -L found gpu count: 1
[11-29 23:19:37 MainThread @3050295955.py:180] episode50, e_greed:0.09900199999999901, test_reward:9.8
[11-29 23:19:38 MainThread @3050295955.py:180] episode100, e_greed:0.09847399999999848, test_reward:9.2
[11-29 23:19:38 MainThread @3050295955.py:180] episode150, e_greed:0.09797599999999798, test_reward:9.0
[11-29 23:19:39 MainThread @3050295955.py:180] episode200, e_greed:0.09748699999999749, test_reward:9.4
[11-29 23:19:39 MainThread @3050295955.py:180] episode250, e_greed:0.09692799999999693, test_reward:10.8
[11-29 23:19:40 MainThread @3050295955.py:180] episode300, e_greed:0.09634299999999635, test_reward:10.0
[11-29 23:19:41 MainThread @3050295955.py:180] episode350, e_greed:0.09570499999999571, test_reward:46.0
[11-29 23:19:43 MainThread @3050295955.py:180] episode400, e_greed:0.09323999999999324, test_reward:10.6
[11-29 23:19:53 MainThread @3050295955.py:180] episode450, e_greed:0.08461999999998462, test_reward:200.0
[11-29 23:20:04 MainThread @3050295955.py:180] episode500, e_greed:0.07530999999997531, test_reward:197.8
[11-29 23:20:15 MainThread @3050295955.py:180] episode550, e_greed:0.06622799999996623, test_reward:145.4
[11-29 23:20:26 MainThread @3050295955.py:180] episode600, e_greed:0.05703799999995704, test_reward:135.2
[11-29 23:20:36 MainThread @3050295955.py:180] episode650, e_greed:0.048734999999948736, test_reward:155.6
[11-29 23:20:47 MainThread @3050295955.py:180] episode700, e_greed:0.0401009999999401, test_reward:147.6
[11-29 23:20:56 MainThread @3050295955.py:180] episode750, e_greed:0.03288899999993289, test_reward:200.0
[11-29 23:21:05 MainThread @3050295955.py:180] episode800, e_greed:0.025805999999925805, test_reward:180.0
[11-29 23:21:14 MainThread @3050295955.py:180] episode850, e_greed:0.018424999999918423, test_reward:121.6
[11-29 23:21:24 MainThread @3050295955.py:180] episode900, e_greed:0.01, test_reward:135.0
[11-29 23:21:35 MainThread @3050295955.py:180] episode950, e_greed:0.01, test_reward:162.2
[11-29 23:21:46 MainThread @3050295955.py:180] episode1000, e_greed:0.01, test_reward:158.2
[11-29 23:21:55 MainThread @3050295955.py:180] episode1050, e_greed:0.01, test_reward:200.0
[11-29 23:22:07 MainThread @3050295955.py:180] episode1100, e_greed:0.01, test_reward:200.0
[11-29 23:22:19 MainThread @3050295955.py:180] episode1150, e_greed:0.01, test_reward:117.0
[11-29 23:22:31 MainThread @3050295955.py:180] episode1200, e_greed:0.01, test_reward:200.0
[11-29 23:22:43 MainThread @3050295955.py:180] episode1250, e_greed:0.01, test_reward:200.0
[11-29 23:22:55 MainThread @3050295955.py:180] episode1300, e_greed:0.01, test_reward:112.8
[11-29 23:23:07 MainThread @3050295955.py:180] episode1350, e_greed:0.01, test_reward:200.0
[11-29 23:23:19 MainThread @3050295955.py:180] episode1400, e_greed:0.01, test_reward:200.0
[11-29 23:23:31 MainThread @3050295955.py:180] episode1450, e_greed:0.01, test_reward:200.0
[11-29 23:23:43 MainThread @3050295955.py:180] episode1500, e_greed:0.01, test_reward:200.0
[11-29 23:23:53 MainThread @3050295955.py:180] episode1550, e_greed:0.01, test_reward:200.0
[11-29 23:24:04 MainThread @3050295955.py:180] episode1600, e_greed:0.01, test_reward:155.6
[11-29 23:24:15 MainThread @3050295955.py:180] episode1650, e_greed:0.01, test_reward:102.6
[11-29 23:24:23 MainThread @3050295955.py:180] episode1700, e_greed:0.01, test_reward:121.6
[11-29 23:24:34 MainThread @3050295955.py:180] episode1750, e_greed:0.01, test_reward:120.6
[11-29 23:24:44 MainThread @3050295955.py:180] episode1800, e_greed:0.01, test_reward:200.0
[11-29 23:24:54 MainThread @3050295955.py:180] episode1850, e_greed:0.01, test_reward:116.4
[11-29 23:25:07 MainThread @3050295955.py:180] episode1900, e_greed:0.01, test_reward:200.0
[11-29 23:25:18 MainThread @3050295955.py:180] episode1950, e_greed:0.01, test_reward:200.0
[11-29 23:25:31 MainThread @3050295955.py:180] episode2000, e_greed:0.01, test_reward:200.0
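To reuse the trained agent later, PARL's Agent base class pairs save with a restore method; a minimal sketch (assuming the same Model/DQN/Agent construction as above):

model = Model(obs_shape[0], action_dim)
algorithm = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
agent = Agent(algorithm, act_dim=action_dim, e_greed=0.01, e_greed_decrement=0)
agent.restore('./dqn_model.ckpt')          # load the saved weights
print(evaluate(env, agent, render=False))  # greedy evaluation; around 200 on CartPole-v0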