paddle2.2.0: training the CartPole game with the DQN algorithm

DQN builds on Q-learning but combines it with a neural network: instead of storing Q values in a Q-table, it fits Q(s, a) with a network, which greatly reduces memory usage and also saves time.

On top of Q-learning, DQN adds two new tricks: experience replay and a fixed (target) Q network.
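
Each training step samples a mini-batch of transitions (obs, act, reward, next_obs, done) from the replay buffer and regresses the current Q estimate toward a target computed with the frozen target network; this is the quantity the learn() method below computes:

    target = reward + γ · (1 − done) · max_a′ Q_target(next_obs, a′)
    loss   = MSE( Q(obs, act), target )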

import copy
import collections
import random

import gym
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl
from parl.utils import logger

LEARN_FREQ = 5            # learn every 5 steps of environment interaction
MEMORY_SIZE = 20000       # replay buffer capacity
MEMORY_WARMUP_SIZE = 200  # fill the buffer with this many transitions before learning
BATCH_SIZE = 32           # mini-batch size sampled from the replay buffer
LEARNING_RATE = 0.001
GAMMA = 0.99              # reward discount factor

# Q network: a 3-layer MLP that maps an observation to one Q value per action.
class Model(parl.Model):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, act_dim)

    def forward(self, x):
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        Q = self.fc3(h2)
        return Q

# Agent: epsilon-greedy sampling for training, greedy prediction for evaluation,
# and a learn() method that syncs the target network every update_target_steps.
class Agent(parl.Agent):
    def __init__(self, algorithm, act_dim, e_greed=None, e_greed_decrement=None):
        super().__init__(algorithm)
        self.e_greed = e_greed
        self.e_greed_decrement = e_greed_decrement
        self.act_dim = act_dim
        self.global_steps = 0
        self.update_target_steps = 200

    def sample(self, obs):
        sample1 = np.random.random()
        if sample1 > self.e_greed:
            act = self.predict(obs)
        else:
            act = np.random.randint(self.act_dim)
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return act

    def predict(self, obs):
        obs = paddle.to_tensor(obs, dtype='float32')
        predQ = self.alg.predict(obs)
        act = paddle.argmax(predQ).numpy()[0]
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        if self.global_steps % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_steps += 1

        act = np.expand_dims(act, axis=-1)
        reward = np.expand_dims(reward, axis=-1)
        terminal = np.expand_dims(terminal, axis=-1)

        obs = paddle.to_tensor(obs, dtype='float32')
        act = paddle.to_tensor(act, dtype='int32')
        reward = paddle.to_tensor(reward, dtype='float32')
        next_obs = paddle.to_tensor(next_obs, dtype='float32')
        terminal = paddle.to_tensor(terminal, dtype='float32')

        loss = self.alg.learn(obs, act, reward, next_obs, terminal)
        return loss.numpy()[0]

# DQN algorithm: one Q network plus a deep copy used as the frozen target network.
class DQN(parl.Algorithm):
    def __init__(self, model, gamma=None, lr=None):
        self.gamma = gamma
        self.model = model
        self.target_model = copy.deepcopy(model)
        self.optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())
        self.mse_loss = paddle.nn.MSELoss(reduction='mean')

    def predict(self, obs):
        return self.model(obs)

    def learn(self, obs, act, reward, next_obs, terminal):
        pred_value = self.model(obs)
        act_dim = pred_value.shape[-1]
        act = paddle.squeeze(act, axis=-1)
        act = F.one_hot(act, num_classes=act_dim)
        pred_value = paddle.multiply(pred_value, act)
        pred_value = paddle.sum(pred_value, axis=1, keepdim=True)

        with paddle.no_grad():
            max_v = self.target_model(next_obs).max(1, keepdim=True)
            target = reward + (1 - terminal) * self.gamma * max_v
        
        loss = self.mse_loss(pred_value, target)
        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def sync_target(self):
        self.model.sync_weights_to(self.target_model)

# Experience replay: a fixed-size deque of (obs, action, reward, next_obs, done) tuples.
class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer1 = collections.deque(maxlen=max_size)

    def append(self, exp):
        self.buffer1.append(exp)

    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer1, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []
        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)

        return np.array(obs_batch, dtype='float32'), \
                np.array(action_batch, dtype='float32'),\
                np.array(reward_batch, dtype='float32'),\
                np.array(next_obs_batch, dtype='float32'),\
                np.array(done_batch, dtype='float32')

    def __len__(self):
        return len(self.buffer1)

# Run one training episode: store each transition in the replay buffer and
# learn from a sampled batch every LEARN_FREQ steps once the buffer is warmed up.
def run_episode(env, agent, rpm):
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)
        next_obs, reward, done, _ = env.step(action)
        rpm.append((obs, action, reward, next_obs, done))

        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward, batch_next_obs, batch_done) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, batch_done)

        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward

# Run 5 evaluation episodes with the greedy policy and return the mean reward.
def evaluate(env, agent, render=False):
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
                # img.set_data(env.render(mode='rgb_array'))
                # display.display(plt.gcf())
                # display.clear_output(wait=True)
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)

env = gym.make('CartPole-v0')
# img = plt.imshow(env.render(mode='rgb_array'))
action_dim = env.action_space.n
obs_shape = env.observation_space.shape

rpm = ReplayMemory(MEMORY_SIZE)
model  = Model(obs_shape[0], action_dim)
algorithm = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
agent = Agent(
    algorithm,
    act_dim=action_dim,
    e_greed=0.1,
    e_greed_decrement=1e-6
)

# Warm up: fill the replay buffer before training starts.
while len(rpm) < MEMORY_WARMUP_SIZE:
    run_episode(env, agent, rpm)

max_episode = 2000

episode = 0
# Train in chunks of 50 episodes, evaluating the greedy policy after each chunk.
while episode < max_episode:
    for i in range(0, 50):
        total_reward = run_episode(env, agent, rpm)
        episode += 1
    eval_reward = evaluate(env, agent, render=False)
    logger.info('episode{},  e_greed:{},   test_reward:{}'.format(episode, agent.e_greed, eval_reward))
save_path = './dqn_model.ckpt'
agent.save(save_path)

Output:

After training, the cart keeps the pole balanced, and the evaluation reward reaches the CartPole-v0 cap of 200.

[Figure 1]

[11-29 23:19:36 MainThread @machine_info.py:88] nvidia-smi -L found gpu count: 1
[11-29 23:19:37 MainThread @3050295955.py:180] episode50,  e_greed:0.09900199999999901,   test_reward:9.8
[11-29 23:19:38 MainThread @3050295955.py:180] episode100,  e_greed:0.09847399999999848,   test_reward:9.2
[11-29 23:19:38 MainThread @3050295955.py:180] episode150,  e_greed:0.09797599999999798,   test_reward:9.0
[11-29 23:19:39 MainThread @3050295955.py:180] episode200,  e_greed:0.09748699999999749,   test_reward:9.4
[11-29 23:19:39 MainThread @3050295955.py:180] episode250,  e_greed:0.09692799999999693,   test_reward:10.8
[11-29 23:19:40 MainThread @3050295955.py:180] episode300,  e_greed:0.09634299999999635,   test_reward:10.0
[11-29 23:19:41 MainThread @3050295955.py:180] episode350,  e_greed:0.09570499999999571,   test_reward:46.0
[11-29 23:19:43 MainThread @3050295955.py:180] episode400,  e_greed:0.09323999999999324,   test_reward:10.6
[11-29 23:19:53 MainThread @3050295955.py:180] episode450,  e_greed:0.08461999999998462,   test_reward:200.0
[11-29 23:20:04 MainThread @3050295955.py:180] episode500,  e_greed:0.07530999999997531,   test_reward:197.8
[11-29 23:20:15 MainThread @3050295955.py:180] episode550,  e_greed:0.06622799999996623,   test_reward:145.4
[11-29 23:20:26 MainThread @3050295955.py:180] episode600,  e_greed:0.05703799999995704,   test_reward:135.2
[11-29 23:20:36 MainThread @3050295955.py:180] episode650,  e_greed:0.048734999999948736,   test_reward:155.6
[11-29 23:20:47 MainThread @3050295955.py:180] episode700,  e_greed:0.0401009999999401,   test_reward:147.6
[11-29 23:20:56 MainThread @3050295955.py:180] episode750,  e_greed:0.03288899999993289,   test_reward:200.0
[11-29 23:21:05 MainThread @3050295955.py:180] episode800,  e_greed:0.025805999999925805,   test_reward:180.0
[11-29 23:21:14 MainThread @3050295955.py:180] episode850,  e_greed:0.018424999999918423,   test_reward:121.6
[11-29 23:21:24 MainThread @3050295955.py:180] episode900,  e_greed:0.01,   test_reward:135.0
[11-29 23:21:35 MainThread @3050295955.py:180] episode950,  e_greed:0.01,   test_reward:162.2
[11-29 23:21:46 MainThread @3050295955.py:180] episode1000,  e_greed:0.01,   test_reward:158.2
[11-29 23:21:55 MainThread @3050295955.py:180] episode1050,  e_greed:0.01,   test_reward:200.0
[11-29 23:22:07 MainThread @3050295955.py:180] episode1100,  e_greed:0.01,   test_reward:200.0
[11-29 23:22:19 MainThread @3050295955.py:180] episode1150,  e_greed:0.01,   test_reward:117.0
[11-29 23:22:31 MainThread @3050295955.py:180] episode1200,  e_greed:0.01,   test_reward:200.0
[11-29 23:22:43 MainThread @3050295955.py:180] episode1250,  e_greed:0.01,   test_reward:200.0
[11-29 23:22:55 MainThread @3050295955.py:180] episode1300,  e_greed:0.01,   test_reward:112.8
[11-29 23:23:07 MainThread @3050295955.py:180] episode1350,  e_greed:0.01,   test_reward:200.0
[11-29 23:23:19 MainThread @3050295955.py:180] episode1400,  e_greed:0.01,   test_reward:200.0
[11-29 23:23:31 MainThread @3050295955.py:180] episode1450,  e_greed:0.01,   test_reward:200.0
[11-29 23:23:43 MainThread @3050295955.py:180] episode1500,  e_greed:0.01,   test_reward:200.0
[11-29 23:23:53 MainThread @3050295955.py:180] episode1550,  e_greed:0.01,   test_reward:200.0
[11-29 23:24:04 MainThread @3050295955.py:180] episode1600,  e_greed:0.01,   test_reward:155.6
[11-29 23:24:15 MainThread @3050295955.py:180] episode1650,  e_greed:0.01,   test_reward:102.6
[11-29 23:24:23 MainThread @3050295955.py:180] episode1700,  e_greed:0.01,   test_reward:121.6
[11-29 23:24:34 MainThread @3050295955.py:180] episode1750,  e_greed:0.01,   test_reward:120.6
[11-29 23:24:44 MainThread @3050295955.py:180] episode1800,  e_greed:0.01,   test_reward:200.0
[11-29 23:24:54 MainThread @3050295955.py:180] episode1850,  e_greed:0.01,   test_reward:116.4
[11-29 23:25:07 MainThread @3050295955.py:180] episode1900,  e_greed:0.01,   test_reward:200.0
[11-29 23:25:18 MainThread @3050295955.py:180] episode1950,  e_greed:0.01,   test_reward:200.0
[11-29 23:25:31 MainThread @3050295955.py:180] episode2000,  e_greed:0.01,   test_reward:200.0
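
To watch the trained cart yourself, reload the saved checkpoint and rerun the evaluation loop with rendering turned on. A minimal sketch, assuming the classes and constants defined above are still in scope and using parl.Agent's restore() to load the weights written by agent.save():

env = gym.make('CartPole-v0')
model = Model(env.observation_space.shape[0], env.action_space.n)
algorithm = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
agent = Agent(algorithm, act_dim=env.action_space.n, e_greed=0.0, e_greed_decrement=0)
agent.restore('./dqn_model.ckpt')  # load the checkpoint saved by agent.save(save_path)
print('mean reward over 5 episodes:', evaluate(env, agent, render=True))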
