深度强化学习-A3C算法

论文地址:https://arxiv.org/pdf/1602.01783v1.pdf

        A3C(异步优势演员评论家)算法,设计该算法的目的是找到能够可靠的训练深度神经网络,且不需要大量资源的RL算法。

        在DQN算法中,为了方便收敛使用了经验回放的技巧。A3C更进一步,并克服了一些经验回放的问题。如,回放池经验数据相关性太强,用于训练的时候效果很可能不佳。举个例子,我们学习下棋,总是和同一个人下,期望能提高棋艺。这当然是可行的,但是到一定程度就再难提高了,此时最好的方法是另寻高手切磋。

        A3C也是基于此思路,利用多线程的方法,同时在多个线程里面分别和环境进行交互学习,每个线程都把学习的经验汇总起来,整理保存在一个公共的地方。并且,定期从公共的地方把其他线程的学习经验拿回来,指导自己和环境的学习交互。

        A3C解决了Actor-Critic难以收敛的问题,更重要的是,提供了一种通用的异步的并发的强化学习框架,也就是说,这个并发框架不光可以用于A3C,还可以用于其他的强化学习算法。这是A3C最大的贡献。目前,已经有基于GPUA3C框架,这样A3C的框架训练速度就更快了。除了A3C, DDPG算法也可以改善Actor-Critic难收敛的问题。它使用了Nature DQNDDQN类似的思想,用两个Actor网络,两个Critic网络,一共4个神经网络来迭代更新模型参数。

基于经验回放有以下几个缺点:

  1. 每次实际交互使用需要更多的内存和计算
  2. 需要能够从旧策略生成的数据进行更新的异策略学些算法

A3C的优点:

  1. 可以不需要依赖GPU的算力,仅用多核CPU
  2. 在atari 2600环境中的许多游戏都取得了更好的结果
  3. 相较以前使用GPU的算法,A3C计算时间短,资源消耗少
  4. 可应用于离散和连续动作空间的任务 

算法流程

        算法伪代码 

深度强化学习-A3C算法_第1张图片

代码实现(部分)

        model:

class ActorCritic(torch.nn.Module):
    def __init__(self, num_inputs, action_space):
        super(ActorCritic, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        self.lstm = nn.LSTMCell(32 * 3 * 3, 256)

        num_outputs = action_space.n
        self.critic_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, num_outputs)

        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

        self.train()

    def forward(self, inputs):
        inputs, (hx, cx) = inputs
        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))

        x = x.view(-1, 32 * 3 * 3)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        return self.critic_linear(x), self.actor_linear(x), (hx, cx)

        这里使用的LSTM的版本,当然也可以替换成CNN版本

        train:

def train(rank, args, shared_model, counter, lock, optimizer=None):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    done = True

    episode_length = 0
    while True:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = torch.zeros(1, 256)
            hx = torch.zeros(1, 256)
        else:
            cx = cx.detach()
            hx = hx.detach()

        values = []
        log_probs = []
        rewards = []
        entropies = []

        env.render()
        for step in range(args.num_steps):
            episode_length += 1
            value, logit, (hx, cx) = model((state.unsqueeze(0),
                                            (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(1, keepdim=True)
            entropies.append(entropy)

            action = prob.multinomial(num_samples=1).detach()
            log_prob = log_prob.gather(1, action)

            state, reward, done, _ = env.step(action.numpy())
            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 1), -1)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
            R = value.detach()

        values.append(R)
        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1] - values[i]
            gae = gae * args.gamma * args.gae_lambda + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]

        optimizer.zero_grad()

        (policy_loss + args.value_loss_coef * value_loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)
        optimizer.step()

        main函数:

if __name__ == '__main__':

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    env = create_atari_env(args.env_name)
    shared_model = ActorCritic(
        env.observation_space.shape[0], env.action_space)
    shared_model.share_memory()

    if args.no_shared:
        optimizer = None
    else:
        optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
        optimizer.share_memory()

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    p = mp.Process(target=test, args=(args.num_processes, args, shared_model, counter))
    p.start()
    processes.append(p)

    for rank in range(0, args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

算法图解

待更新

你可能感兴趣的:(强化学习)