Deep Reinforcement Learning Series (10): NoisyNet-DQN Principles and Implementation


Paper: https://arxiv.org/pdf/1706.10295v1.pdf
This paper was published by DeepMind at ICLR 2018. The first author is Meire Fortunato, and familiar names such as Mnih also appear among the co-authors. As usual, we go through it in the familiar reading order.

This paper addresses the exploration problem (efficient exploration) in reinforcement learning. The authors add learnable noise parameters to the training network, which are updated together with the network weights during gradient descent. The results show that, at a small additional computational cost, this approach outperforms conventional exploration heuristics when applied to A3C, DQN, Dueling DQN, and related algorithms.

1. Background and Problem

For the exploration-exploitation trade-off, two approaches are commonly used:

  • $\epsilon$-greedy (with $\epsilon$ given as a hyperparameter): with probability $\epsilon$ the agent takes a random action instead of acting according to its learned policy. The usual practice is to start training with $\epsilon = 1$ and gradually anneal it to a small value such as 0.1 or 0.02 (a minimal decay-schedule sketch follows this list).

  • Entropy regularization: used with policy-gradient methods, the entropy of the policy is added to the loss function to penalize the model for being overly certain about its actions.
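
For reference, here is a minimal sketch of such a linear decay schedule; the start value, end value, and decay horizon are illustrative assumptions rather than values from the paper.

```python
def epsilon_by_step(step, eps_start=1.0, eps_end=0.1, decay_steps=1_000_000):
    """Linearly anneal epsilon from eps_start down to eps_end over decay_steps."""
    frac = min(step / decay_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)

print(epsilon_by_step(0), epsilon_by_step(500_000), epsilon_by_step(2_000_000))
# -> 1.0 0.55 0.1
```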

Many exploration heuristics in reinforcement learning follow the principle of "optimism in the face of uncertainty", which allows them to come with theoretical guarantees on the agent's performance. Their drawback is that these guarantees are typically limited to problems with small state and action spaces or to linear function approximation; for problems that require complex function approximators, they do not work nearly as well.

The authors propose NoisyNet, which adds Gaussian noise to the last (fully connected) layers of the network. The parameters of the noise are adjusted by the model during training, allowing the agent to decide when, and to what extent, to inject uncertainty into its weights.

2. Principle and Mathematical Formulation

A NoisyNet is a neural network whose weights and biases are perturbed by noise.

In general, write a NoisyNet as $y = f_{\theta}(x)$, where $x$ is the input, $y$ is the output, and $\theta$ are the noisy parameters. The authors define $\theta$ as

$$\theta \stackrel{\text{def}}{=} \mu + \Sigma \odot \varepsilon$$

where $\zeta \stackrel{\text{def}}{=} (\mu, \Sigma)$ is the set of learnable parameter vectors, $\varepsilon$ is a vector of zero-mean noise with fixed statistics, and $\odot$ denotes element-wise multiplication. The loss associated with the noisy parameters is defined as the expectation over the noise $\varepsilon$:

$$\bar{L}(\zeta) \stackrel{\text{def}}{=} \mathbb{E}[L(\theta)]$$
Training then amounts to optimizing $\zeta$. How is this done?

Now consider a simple question. An ordinary linear layer with $p$ inputs and $q$ outputs is written as $y = wx + b$, where $w \in \mathbb{R}^{q \times p}$, $x \in \mathbb{R}^{p}$, and $b \in \mathbb{R}^{q}$. What happens if we add noise to these parameters (i.e., to the network itself)? The authors define the corresponding noisy linear layer as

$$y \stackrel{\text{def}}{=}\left(\mu^{w}+\sigma^{w} \odot \varepsilon^{w}\right) x+\mu^{b}+\sigma^{b} \odot \varepsilon^{b}$$

This looks complicated at first glance, but it is just the usual linear layer with $w$ replaced by $\mu^{w}+\sigma^{w} \odot \varepsilon^{w}$ and $b$ replaced by $\mu^{b}+\sigma^{b} \odot \varepsilon^{b}$. The dimensions of the parameters are:

| $\mu$ | $\sigma$ | $\varepsilon$ |
| --- | --- | --- |
| $\mu^{w} \in \mathbb{R}^{q \times p}$ | $\sigma^{w} \in \mathbb{R}^{q \times p}$ | $\varepsilon^{w} \in \mathbb{R}^{q \times p}$ |
| $\mu^{b} \in \mathbb{R}^{q}$ | $\sigma^{b} \in \mathbb{R}^{q}$ | $\varepsilon^{b} \in \mathbb{R}^{q}$ |

where $\varepsilon$ denotes the random noise parameters. The figure below (from the paper) illustrates the structure of a noisy linear layer:

[Figure: graphical representation of a noisy linear layer]
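
As a quick check on the shapes in the table above, here is a minimal NumPy sketch of one forward pass through a noisy linear layer with independent noise; all parameter values below are arbitrary illustrations.

```python
import numpy as np

p, q = 4, 3                                   # number of inputs and outputs
x = np.random.randn(p)

mu_w, sigma_w = np.zeros((q, p)), np.full((q, p), 0.5)
mu_b, sigma_b = np.zeros(q), np.full(q, 0.5)

# Independent Gaussian noise: one sample per weight and one per bias
eps_w = np.random.randn(q, p)
eps_b = np.random.randn(q)

y = (mu_w + sigma_w * eps_w) @ x + mu_b + sigma_b * eps_b
print(y.shape)                                # (3,) == (q,)
```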

The above describes how the noise is injected. In the paper, the authors consider two ways of generating the noise:

  • Independent Gaussian noise (Independent Gaussian Noise): every weight and bias of the noisy layer has its own independent noise entry, i.e., each $\varepsilon^{w}_{i,j}$ (and each $\varepsilon^{b}_{j}$) is drawn from a unit Gaussian, while the corresponding $\mu$ and $\sigma$ are learned by the model. This requires $(pq+q)$ noise variables per layer.

  • Factorised Gaussian noise (Factorised Gaussian Noise): uses one noise vector on each side of the layer, a Gaussian vector $\varepsilon_{i}$ with $p$ entries for the inputs and a Gaussian vector $\varepsilon_{j}$ with $q$ entries for the outputs, for a total of $(p+q)$ noise variables. The weight and bias noise are then factorised as follows (a short NumPy sketch of this factorisation comes right after this list):
    $$\begin{aligned} \varepsilon_{i, j}^{w} &=f\left(\varepsilon_{i}\right) f\left(\varepsilon_{j}\right) \\ \varepsilon_{j}^{b} &=f\left(\varepsilon_{j}\right) \end{aligned}$$
    where $f$ is the real-valued function $f(x)=\operatorname{sgn}(x) \sqrt{|x|}$.
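
A minimal NumPy sketch of how the factorised noise is built from only $p + q$ Gaussian samples (the shapes and the outer-product convention are my own illustration; the TensorFlow `noisy_dense` code later in this post does the same thing with `tf` ops):

```python
import numpy as np

def f(x):
    return np.sign(x) * np.sqrt(np.abs(x))

p, q = 4, 3
eps_in = np.random.randn(p)                  # p noise samples for the inputs
eps_out = np.random.randn(q)                 # q noise samples for the outputs

eps_w = np.outer(f(eps_out), f(eps_in))      # shape (q, p): eps_w[j, i] = f(eps_out[j]) * f(eps_in[i])
eps_b = f(eps_out)                           # shape (q,)
print(eps_w.shape, eps_b.shape)
```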

Since $\bar{L}(\zeta)=\mathbb{E}[L(\theta)]$ is an expectation over the noise, its gradient with respect to $\zeta=(\mu, \Sigma)$ is

$$\nabla \bar{L}(\zeta)=\nabla \mathbb{E}[L(\theta)]=\mathbb{E}\left[\nabla_{\mu, \Sigma} L(\mu+\Sigma \odot \varepsilon)\right]$$

In practice this expectation is approximated with a single Monte Carlo sample $\xi$ of the noise per optimization step:

$$\nabla \bar{L}(\zeta) \approx \nabla_{\mu, \Sigma} L(\mu+\Sigma \odot \xi)$$
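
The following PyTorch fragment illustrates this single-sample estimate (the quadratic `loss_fn` is only a stand-in for $L$): the noise sample $\xi$ is held fixed, so gradients flow into $\mu$ and $\sigma$ through the reparameterisation $\theta = \mu + \Sigma \odot \xi$.

```python
import torch

mu = torch.zeros(4, requires_grad=True)
sigma = torch.full((4,), 0.5, requires_grad=True)

def loss_fn(theta):                 # stand-in for L(theta)
    return (theta ** 2).sum()

xi = torch.randn(4)                 # one noise sample for this optimisation step
loss = loss_fn(mu + sigma * xi)     # theta = mu + sigma ⊙ xi
loss.backward()                     # gradients accumulate in mu.grad and sigma.grad only
print(mu.grad, sigma.grad)
```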

3. Deep NoisyNet: Update Rules and Initialization

Note: the noise here is applied to the value function (action-value function), not to the actions output by the policy.

3.1 NoisyNet update rules for different algorithms

The update rules themselves are simple: re-define the optimization objective as $\bar{L}(\zeta)$ over the noisy parameters (i.e., add the corresponding noise parameters to the original value function):

  • NoisyNet-DQN (a minimal code sketch of this loss follows the list):
    $$\bar{L}(\zeta)=\mathbb{E}\left[\mathbb{E}_{(x, a, r, y) \sim D}\left[r+\gamma \max_{b \in A} Q\left(y, b, \varepsilon^{\prime} ; \zeta^{-}\right)-Q(x, a, \varepsilon ; \zeta)\right]^{2}\right]$$

  • NoisyNet-Dueling DQN:
    $$\begin{aligned} \bar{L}(\zeta) &=\mathbb{E}\left[\mathbb{E}_{(x, a, r, y) \sim D}\left[r+\gamma Q\left(y, b^{*}(y), \varepsilon^{\prime} ; \zeta^{-}\right)-Q(x, a, \varepsilon ; \zeta)\right]^{2}\right] \\ \text{s.t.}\quad b^{*}(y) &=\arg\max_{b \in \mathcal{A}} Q\left(y, b(y), \varepsilon^{\prime\prime} ; \zeta\right) \end{aligned}$$

  • NoisyNet-A3C:
    $$\hat{Q}_{i}=\sum_{j=i}^{k-1} \gamma^{j-i} r_{t+j}+\gamma^{k-i} V\left(x_{t+k} ; \zeta, \varepsilon_{i}\right)$$
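
As a minimal sketch of how the NoisyNet-DQN loss above translates into code (my own illustration, not taken from the repositories referenced below): `online_net` and `target_net` are assumed to be Q-networks whose noisy layers resample their noise through a hypothetical `reset_noise()` method, so that independent draws $\varepsilon$ and $\varepsilon'$ are used for the online parameters $\zeta$ and the target parameters $\zeta^{-}$.

```python
import torch
import torch.nn.functional as F

def noisy_dqn_loss(online_net, target_net, batch, gamma=0.99):
    """One squared-TD-error step with fresh noise for both networks."""
    obs, actions, rewards, next_obs, dones = batch       # tensors sampled from the replay buffer
    online_net.reset_noise()                             # sample epsilon   (for zeta)
    target_net.reset_noise()                             # sample epsilon'  (for zeta^-)
    q = online_net(obs).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target_net(next_obs).max(dim=1).values  # max_b Q(y, b, eps'; zeta^-)
        target = rewards + gamma * (1.0 - dones) * next_q
    return F.mse_loss(q, target)                         # squared error as in the formula; Huber is common in practice
```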

3.2 Initialization of the noise parameters

    1. For the non-factorised (independent) Gaussian noise, each element $\mu_{i,j}$ is sampled from the uniform distribution $\mathcal{U}[-\sqrt{\frac{3}{p}},+\sqrt{\frac{3}{p}}]$, where $p$ is the number of inputs to the layer.
    2. For the factorised Gaussian noise, $\mu_{i,j}$ is sampled from $\mathcal{U}[-\frac{1}{\sqrt{p}},+\frac{1}{\sqrt{p}}]$, and each $\sigma_{i,j}$ is initialised to a constant $\frac{\sigma_{0}}{\sqrt{p}}$ (the code below uses $\sigma_{0}=0.4$).

See the following code:

# Added by Andrew Liao
# for NoisyNet-DQN (using Factorised Gaussian noise)
# modified from the standard ```dense``` function
def noisy_dense(x, size, name, bias=True, activation_fn=tf.identity):

    # the function used in eq.7,8
    def f(x):
        return tf.multiply(tf.sign(x), tf.pow(tf.abs(x), 0.5))
    # Initializer of \mu and \sigma 
    mu_init = tf.random_uniform_initializer(minval=-1*1/np.power(x.get_shape().as_list()[1], 0.5),     
                                                maxval=1*1/np.power(x.get_shape().as_list()[1], 0.5))
    sigma_init = tf.constant_initializer(0.4/np.power(x.get_shape().as_list()[1], 0.5))
    # Sample noise from gaussian
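    # (sample_noise is a helper defined elsewhere in the original repo; it is
    # assumed to return i.i.d. standard-normal samples of the requested shape)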
    p = sample_noise([x.get_shape().as_list()[1], 1])
    q = sample_noise([1, size])
    f_p = f(p); f_q = f(q)
    w_epsilon = f_p*f_q; b_epsilon = tf.squeeze(f_q)

    # w = w_mu + w_sigma*w_epsilon
    w_mu = tf.get_variable(name + "/w_mu", [x.get_shape()[1], size], initializer=mu_init)
    w_sigma = tf.get_variable(name + "/w_sigma", [x.get_shape()[1], size], initializer=sigma_init)
    w = w_mu + tf.multiply(w_sigma, w_epsilon)
    ret = tf.matmul(x, w)
    if bias:
        # b = b_mu + b_sigma*b_epsilon
        b_mu = tf.get_variable(name + "/b_mu", [size], initializer=mu_init)
        b_sigma = tf.get_variable(name + "/b_sigma", [size], initializer=sigma_init)
        b = b_mu + tf.multiply(b_sigma, b_epsilon)
        return activation_fn(ret + b)
    else:
        return activation_fn(ret)
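
For readers more familiar with PyTorch, below is a minimal illustrative port of the same factorised-noise layer as an `nn.Module` (my own sketch, not code from the referenced repositories). It uses the $\mathcal{U}[-\frac{1}{\sqrt{p}},+\frac{1}{\sqrt{p}}]$ initialisation for $\mu$ and the constant $\frac{\sigma_{0}}{\sqrt{p}}$ initialisation for $\sigma$, with $\sigma_{0}=0.4$ as in the TensorFlow code above; calling `reset_noise()` once per optimisation step reproduces the single-sample Monte Carlo estimate from Section 2.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyLinear(nn.Module):
    """Linear layer with factorised Gaussian noise (illustrative sketch)."""

    def __init__(self, in_features, out_features, sigma0=0.4):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        bound = 1.0 / math.sqrt(in_features)
        sigma_init = sigma0 / math.sqrt(in_features)
        # Learnable parameters: mu and sigma for both weights and biases
        self.w_mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.w_sigma = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.b_mu = nn.Parameter(torch.empty(out_features).uniform_(-bound, bound))
        self.b_sigma = nn.Parameter(torch.full((out_features,), sigma_init))
        # Noise buffers (not trained, resampled via reset_noise)
        self.register_buffer("w_eps", torch.zeros(out_features, in_features))
        self.register_buffer("b_eps", torch.zeros(out_features))
        self.reset_noise()

    @staticmethod
    def _f(x):
        return x.sign() * x.abs().sqrt()

    def reset_noise(self):
        eps_in = self._f(torch.randn(self.in_features))
        eps_out = self._f(torch.randn(self.out_features))
        self.w_eps.copy_(torch.outer(eps_out, eps_in))   # epsilon^w_{j,i} = f(eps_out_j) f(eps_in_i)
        self.b_eps.copy_(eps_out)                        # epsilon^b = f(eps_out)

    def forward(self, x):
        w = self.w_mu + self.w_sigma * self.w_eps
        b = self.b_mu + self.b_sigma * self.b_eps
        return F.linear(x, w, b)

# Quick check: layer = NoisyLinear(4, 2); y = layer(torch.randn(8, 4))  # y.shape == (8, 2)
```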

4. Algorithm Pseudocode

[Figures: algorithm pseudocode from the paper]

5. Experimental Results

[Figures: experimental results from the paper]

6. Implementation (only applied to a subset of Atari games)

This section contains code for two algorithms: NoisyNet-DQN and NoisyNet-A3C.

(1) NoisyNet-DQN

# code source: https://github.com/wenh123/NoisyNet-DQN/blob/master/train.py
import argparse
import gym
import numpy as np
import os
import tensorflow as tf
import tempfile
import time

import baselines.common.tf_util as U

from baselines import logger
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from baselines.common.misc_util import (
    boolean_flag,
    pickle_load,
    pretty_eta,
    relatively_safe_pickle_dump,
    set_global_seeds,
    RunningAvg,
    SimpleMonitor
)
from baselines.common.schedules import LinearSchedule, PiecewiseSchedule
# when updating this to non-deperecated ones, it is important to
# copy over LazyFrames
from baselines.common.atari_wrappers_deprecated import wrap_dqn
from baselines.common.azure_utils import Container
from model import model, dueling_model
from statistics import statistics

def parse_args():
    parser = argparse.ArgumentParser("DQN experiments for Atari games")
    # Environment
    parser.add_argument("--env", type=str, default="Pong", help="name of the game")
    parser.add_argument("--seed", type=int, default=42, help="which seed to use")
    # Core DQN parameters
    parser.add_argument("--replay-buffer-size", type=int, default=int(1e6), help="replay buffer size")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for Adam optimizer")
    parser.add_argument("--num-steps", type=int, default=int(2e8), help="total number of steps to run the environment for")
    parser.add_argument("--batch-size", type=int, default=32, help="number of transitions to optimize at the same time")
    parser.add_argument("--learning-freq", type=int, default=4, help="number of iterations between every optimization step")
    parser.add_argument("--target-update-freq", type=int, default=40000, help="number of iterations between every target network update")
    # Bells and whistles
    boolean_flag(parser, "noisy", default=False, help="whether or not to NoisyNetwork")
    boolean_flag(parser, "double-q", default=True, help="whether or not to use double q learning")
    boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
    boolean_flag(parser, "prioritized", default=False, help="whether or not to use prioritized replay buffer")
    parser.add_argument("--prioritized-alpha", type=float, default=0.6, help="alpha parameter for prioritized replay buffer")
    parser.add_argument("--prioritized-beta0", type=float, default=0.4, help="initial value of beta parameters for prioritized replay")
    parser.add_argument("--prioritized-eps", type=float, default=1e-6, help="eps parameter for prioritized replay buffer")
    # Checkpointing
    parser.add_argument("--save-dir", type=str, default=None, required=True, help="directory in which training state and model should be saved.")
    parser.add_argument("--save-azure-container", type=str, default=None,
                        help="It present data will saved/loaded from Azure. Should be in format ACCOUNT_NAME:ACCOUNT_KEY:CONTAINER")
    parser.add_argument("--save-freq", type=int, default=1e6, help="save model once every time this many iterations are completed")
    boolean_flag(parser, "load-on-start", default=True, help="if true and model was previously saved then training will be resumed")
    return parser.parse_args()


def make_env(game_name):
    env = gym.make(game_name + "NoFrameskip-v4")
    monitored_env = SimpleMonitor(env)  # puts rewards and number of steps in info, before environment is wrapped
    env = wrap_dqn(monitored_env)  # applies a bunch of modification to simplify the observation space (downsample, make b/w)
    return env, monitored_env


def maybe_save_model(savedir, container, state):
    """This function checkpoints the model and state of the training algorithm."""
    if savedir is None:
        return
    start_time = time.time()
    model_dir = "model-{}".format(state["num_iters"])
    U.save_state(os.path.join(savedir, model_dir, "saved"))
    if container is not None:
        container.put(os.path.join(savedir, model_dir), model_dir)
    relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'), compression=True)
    if container is not None:
        container.put(os.path.join(savedir, 'training_state.pkl.zip'), 'training_state.pkl.zip')
    relatively_safe_pickle_dump(state["monitor_state"], os.path.join(savedir, 'monitor_state.pkl'))
    if container is not None:
        container.put(os.path.join(savedir, 'monitor_state.pkl'), 'monitor_state.pkl')
    logger.log("Saved model in {} seconds\n".format(time.time() - start_time))


def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return

    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)
    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state


if __name__ == '__main__':
    args = parse_args()
    # Parse savedir and azure container.
    savedir = args.save_dir
    if args.save_azure_container is not None:
        account_name, account_key, container_name = args.save_azure_container.split(":")
        container = Container(account_name=account_name,
                              account_key=account_key,
                              container_name=container_name,
                              maybe_create=True)
        if savedir is None:
            # Careful! This will not get cleaned up. Docker spoils the developers.
            savedir = tempfile.TemporaryDirectory().name
    else:
        container = None
    # Create and seed the env.
    env, monitored_env = make_env(args.env)
    if args.seed > 0:
        set_global_seeds(args.seed)
        env.unwrapped.seed(args.seed)

    with U.make_session(4) as sess:
        # Create training graph and replay buffer
        act, train, update_target, debug = deepq.build_train(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            q_func=dueling_model if args.dueling else model,
            num_actions=env.action_space.n,
            optimizer=tf.train.AdamOptimizer(learning_rate=args.lr, epsilon=1e-4),
            gamma=0.99,
            grad_norm_clipping=10,
            double_q=args.double_q,
            noisy=args.noisy,
        )
        approximate_num_iters = args.num_steps / 4
        exploration = PiecewiseSchedule([
            (0, 1.0),
            (approximate_num_iters / 50, 0.1),
            (approximate_num_iters / 5, 0.01)
        ], outside_value=0.01)

        if args.prioritized:
            replay_buffer = PrioritizedReplayBuffer(args.replay_buffer_size, args.prioritized_alpha)
            beta_schedule = LinearSchedule(approximate_num_iters, initial_p=args.prioritized_beta0, final_p=1.0)
        else:
            replay_buffer = ReplayBuffer(args.replay_buffer_size)

        U.initialize()
        update_target()
        num_iters = 0

        # Load the model
        state = maybe_load_model(savedir, container)
        if state is not None:
            num_iters, replay_buffer = state["num_iters"], state["replay_buffer"],
            monitored_env.set_state(state["monitor_state"])

        start_time, start_steps = None, None
        steps_per_iter = RunningAvg(0.999)
        iteration_time_est = RunningAvg(0.999)
        obs = env.reset()
        # Record the mean of the \sigma
        sigma_name_list = []
        sigma_list = []
        for param in tf.trainable_variables():
            # only record the \sigma in the action network
            if 'sigma' in param.name and 'deepq/q_func/action_value' in param.name:
                summary_name = param.name.replace('deepq/q_func/action_value/', '').replace('/', '.').split(':')[0]
                sigma_name_list.append(summary_name)
                sigma_list.append(tf.reduce_mean(tf.abs(param)))
        f_mean_sigma = U.function(inputs=[], outputs=sigma_list)
        # Statistics
        writer = tf.summary.FileWriter(savedir, sess.graph)
        im_stats = statistics(scalar_keys=['action', 'im_reward', 'td_errors', 'huber_loss']+sigma_name_list)
        ep_stats = statistics(scalar_keys=['ep_reward', 'ep_length'])  
        # Main training loop
        ep_length = 0
        while True:
            num_iters += 1
            ep_length += 1
            # Take action and store transition in the replay buffer.
            if args.noisy:
                # greedily choose
                action = act(np.array(obs)[None], stochastic=False)[0]
            else:
                # epsilon greedy
                action = act(np.array(obs)[None], update_eps=exploration.value(num_iters))[0]
            new_obs, rew, done, info = env.step(action)
            replay_buffer.add(obs, action, rew, new_obs, float(done))
            obs = new_obs
            if done:
                obs = env.reset()

            if (num_iters > max(5 * args.batch_size, args.replay_buffer_size // 20) and
                    num_iters % args.learning_freq == 0):
                # Sample a bunch of transitions from replay buffer
                if args.prioritized:
                    experience = replay_buffer.sample(args.batch_size, beta=beta_schedule.value(num_iters))
                    (obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
                else:
                    obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(args.batch_size)
                    weights = np.ones_like(rewards)
                # Minimize the error in Bellman's equation and compute TD-error
                td_errors, huber_loss = train(obses_t, actions, rewards, obses_tp1, dones, weights)
                # Update the priorities in the replay buffer
                if args.prioritized:
                    new_priorities = np.abs(td_errors) + args.prioritized_eps
                    replay_buffer.update_priorities(batch_idxes, new_priorities)
                # Write summary
                mean_sigma = f_mean_sigma()
                im_stats.add_all_summary(writer, [action, rew, np.mean(td_errors), np.mean(huber_loss)]+mean_sigma, num_iters)

            # Update target network.
            if num_iters % args.target_update_freq == 0:
                update_target()

            if start_time is not None:
                steps_per_iter.update(info['steps'] - start_steps)
                iteration_time_est.update(time.time() - start_time)
            start_time, start_steps = time.time(), info["steps"]

            # Save the model and training state.
            if num_iters > 0 and (num_iters % args.save_freq == 0 or info["steps"] > args.num_steps):
                maybe_save_model(savedir, container, {
                    'replay_buffer': replay_buffer,
                    'num_iters': num_iters,
                    'monitor_state': monitored_env.get_state()
                })

            if info["steps"] > args.num_steps:
                break

            if done:
                steps_left = args.num_steps - info["steps"]
                completion = np.round(info["steps"] / args.num_steps, 1)
                mean_ep_reward = np.mean(info["rewards"][-100:])
                logger.record_tabular("% completion", completion)
                logger.record_tabular("steps", info["steps"])
                logger.record_tabular("iters", num_iters)
                logger.record_tabular("episodes", len(info["rewards"]))
                logger.record_tabular("reward (100 epi mean)", np.mean(info["rewards"][-100:]))
                if not args.noisy:
                    logger.record_tabular("exploration", exploration.value(num_iters))
                if args.prioritized:
                    logger.record_tabular("max priority", replay_buffer._max_priority)
                fps_estimate = (float(steps_per_iter) / (float(iteration_time_est) + 1e-6)
                                if steps_per_iter._value is not None else "calculating...")
                logger.dump_tabular()
                logger.log()
                logger.log("ETA: " + pretty_eta(int(steps_left / fps_estimate)))
                logger.log()
                # add summary for one episode
                ep_stats.add_all_summary(writer, [mean_ep_reward, ep_length], num_iters)
                ep_length = 0

(2) NoisyNet-A3C

# using Pytorch
# code source: https://github.com/Kaixhin/NoisyNet-A3C
import gym
import torch
from torch import nn
from torch.autograd import Variable

from model import ActorCritic
from utils import state_to_tensor


# Transfers gradients from thread-specific model to shared model
def _transfer_grads_to_shared_model(model, shared_model):
  for param, shared_param in zip(model.parameters(), shared_model.parameters()):
    if shared_param.grad is not None:
      return
    shared_param._grad = param.grad


# Adjusts learning rate
def _adjust_learning_rate(optimiser, lr):
  for param_group in optimiser.param_groups:
    param_group['lr'] = lr


def train(rank, args, T, shared_model, optimiser):
  torch.manual_seed(args.seed + rank)

  env = gym.make(args.env)
  env.seed(args.seed + rank)
  model = ActorCritic(env.observation_space, env.action_space, args.hidden_size, args.sigma_init, args.no_noise)
  model.train()

  t = 1  # Thread step counter
  done = True  # Start new episode

  while T.value() <= args.T_max:
    # Sync with shared model at least every t_max steps
    model.load_state_dict(shared_model.state_dict())
    # Get starting timestep
    t_start = t

    # Reset or pass on hidden state
    if done:
      hx = Variable(torch.zeros(1, args.hidden_size))
      cx = Variable(torch.zeros(1, args.hidden_size))
      # Reset environment and done flag
      state = state_to_tensor(env.reset())
      done, episode_length = False, 0
    else:
      # Perform truncated backpropagation-through-time (allows freeing buffers after backwards call)
      hx = hx.detach()
      cx = cx.detach()
    model.sample_noise()  # Pick a new noise vector (until next optimisation step)

    # Lists of outputs for training
    values, log_probs, rewards, entropies = [], [], [], []

    while not done and t - t_start < args.t_max:
      # Calculate policy and value
      policy, value, (hx, cx) = model(Variable(state), (hx, cx))
      log_policy = policy.log()
      entropy = -(log_policy * policy).sum(1)

      # Sample action
      action = policy.multinomial()
      log_prob = log_policy.gather(1, action.detach())  # Graph broken as loss for stochastic action calculated manually
      action = action.data[0, 0]

      # Step
      state, reward, done, _ = env.step(action)
      state = state_to_tensor(state)
      reward = args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
      done = done or episode_length >= args.max_episode_length

      # Save outputs for training
      [arr.append(el) for arr, el in zip((values, log_probs, rewards, entropies), (value, log_prob, reward, entropy))]

      # Increment counters
      t += 1
      T.increment()

    # Return R = 0 for terminal s or V(s_i; θ) for non-terminal s
    if done:
      R = Variable(torch.zeros(1, 1))
    else:
      _, R, _ = model(Variable(state), (hx, cx))
      R = R.detach()

    # Calculate n-step returns in forward view, stepping backwards from the last state
    trajectory_length = len(rewards)
    values, log_probs, entropies = torch.cat(values), torch.cat(log_probs), torch.cat(entropies)
    returns = Variable(torch.Tensor(trajectory_length + 1, 1))
    returns[-1] = R
    for i in reversed(range(trajectory_length)):
      # R ← r_i + γR
      returns[i] = rewards[i] + args.discount * returns[i + 1]
    # Advantage A = R - V(s_i; θ)
    A = returns[:-1] - values
    # dθ ← dθ - ∂A^2/∂θ
    value_loss = 0.5 * A ** 2  # Least squares error

    # dθ ← dθ + ∇θ∙log(π(a_i|s_i; θ))∙A
    policy_loss = -log_probs * A.detach()  # Policy gradient loss (detached from critic)
    # dθ ← dθ + β∙∇θH(π(s_i; θ))
    policy_loss -= args.entropy_weight * entropies.unsqueeze(1)  # Entropy maximisation loss
    # Zero shared and local grads
    optimiser.zero_grad()
    # Note that losses were defined as negatives of normal update rules for gradient descent
    (policy_loss + value_loss).sum().backward()
    # Gradient L2 normalisation
    nn.utils.clip_grad_norm(model.parameters(), args.max_gradient_norm, 2)

    # Transfer gradients to shared model and update
    _transfer_grads_to_shared_model(model, shared_model)
    optimiser.step()
    if not args.no_lr_decay:
      # Linearly decay learning rate
      _adjust_learning_rate(optimiser, max(args.lr * (args.T_max - T.value()) / args.T_max, 1e-32))

  env.close()

References

  1. https://arxiv.org/pdf/1706.10295v1.pdf
  2. https://arxiv.org/abs/1602.01783
  3. https://github.com/openai/baselines
  4. https://github.com/Kaixhin/NoisyNet-A3C
  5. https://github.com/wenh123/NoisyNet-DQN/
