AdamW, LAMB: 大型预训练模型常用优化器

前言

按照时间上的迭代顺序,近些年神经网络先后出现了 Gradient Descent (GD)、Momentum、Adaptive Gradient (AdaGrad)、Root Mean Square prop (RMSprop)、Adaptive Moment estimation (Adam) 等优秀的优化器。到如今,大部分 NLP 预训练模型已不再使用这些方法,而是使用 Adam Weight Decay Regularization (AdamW) 和去年首度亮相的 Layer-wise Adaptive Moments optimizer for Batching training (LAMB)。为何最为传统的 GD,包括衍生的 stochastic GD、mini-batch GD 优化器已不再使用,下文会有详细的介绍。

Gradient Descent (GD)

梯度下降法是最为经典的凸优化优化器,思想也非常明确:通过 loss 反向传导计算参数的梯度,参数往哪个方向跑可以让 loss 下降,就让参数往哪个方向更新:
Δ W k = ∂ l o s s ∂ W k = ∂ l o s s ∂ Z n ∂ Z n ∂ Z n − 1 . . . ∂ Z k + 1 ∂ W k \Delta W_k=\frac{\partial loss}{\partial W_k}=\frac{\partial loss}{\partial Z_n}\frac{\partial Z_n}{\partial Z_{n-1}}...\frac{\partial Z_{k+1}}{\partial W_k} ΔWk=Wkloss=ZnlossZn1Zn...WkZk+1

W k ← W k − α Δ W k W_k\leftarrow W_k-\alpha \Delta W_k WkWkαΔWk

需要注意的是, W k W_k Wk 中的每一个浮点元素的梯度计算和梯度更新,相互之间是完全独立的,这对于理解梯度更新的机理非常重要。上式中, α \alpha α 为学习率,通常是一个固定的超参数,学习率越高,收敛越快。但需要注意控制范围。学习率过大,容易造成梯度跨过参数的局部最优点造成参数震荡;学习率过小,会导致训练过程过于漫长。为避免参数震荡,使用 GD 时,学习率通常设置在一个较低值,且训练的 batch_size 越大,学习率越低。梯度裁剪虽能一定程度上解决梯度震荡的问题,但由于输出的概率分布发生偏移,模型收敛也受到一定负面影响,因此需尽可能避免对梯度裁剪的依赖。

Adaptive Moment estimation (Adam)

为解决 GD 中固定学习率带来的不同参数间收敛速度不一致的弊端,AdaGrad 和 RMSprop 诞生出来,为每个参数赋予独立的学习率。计算梯度后,梯度较大的参数获得的学习率较低,反之亦然。此外,为避免每次梯度更新时都独立计算梯度,导致梯度方向持续变化,Momentum 将上一轮梯度值加入到当前梯度的计算中,通过某种权重对两者加权求和,获得当前批次参数更新的更新值。 Adam 结合了这两项考虑,既为每一个浮点参数自适应性地设置学习率,又将过去的梯度历史纳入考量:
m t = β 1 m t − 1 + ( 1 − β 1 ) Δ W m_t=\beta_1m_{t-1}+(1-\beta_1)\Delta W mt=β1mt1+(1β1)ΔW

v t = β 2 v t − 1 + ( 1 − β 2 ) Δ W 2 v_t=\beta_2v_{t-1}+(1-\beta_2)\Delta W^2 vt=β2vt1+(1β2)ΔW2

m t ^ = m t 1 − β 1 t \hat{m_t}=\frac{m_t}{1-\beta_1^t} mt^=1β1tmt

v t ^ = v t 1 − β 2 t \hat{v_t}=\frac{v_t}{1-\beta_2^t} vt^=1β2tvt

W t ← W t − 1 − α v t ^ + ϵ m t ^ W_t\leftarrow W_{t-1}-\frac{\alpha}{\sqrt{\hat{v_t}}+\epsilon}\hat{m_t} WtWt1vt^ +ϵαmt^

实际使用中,通常 β 1 = 0.9 \beta_1=0.9 β1=0.9 β 2 > 0.9 \beta_2>0.9 β2>0.9。BERT 源代码中,预训练的 β 2 \beta_2 β2 为 0.98,微调的 β 2 \beta_2 β2 为 0.999,其目的是为了减少对预训练中得到的原始参数结构的破坏,使收敛更为平缓。此外, m 0 m_0 m0 v 0 v_0 v0 皆为初始化得来,因此训练时参数种子的设置往往对模型结果的影响较大。从上述公式可以看出,训练前期的学习率和梯度更新是比较激进的,到后期逐渐平稳。

虽然 Adam 优化器的使用会导致内存中多出两倍于原参数体量的占用,但与之换来的训练收益使得学术界并没有放弃这一高效的方法。

Adam Weight Decay Regularization (AdamW)

Adam 虽然收敛速度快,但没能解决参数过拟合的问题。学术界讨论了诸多方案,其中包括在损失函数中引入参数的 L2 正则项。这样的方法在其他的优化器中或许有效,但会因为 Adam 中自适应学习率的存在而对使用 Adam 优化器的模型失效。AdamW 的出现便是为了解决这一问题,达到同样使参数接近于 0 的目的。具体的举措,是在最终的参数更新时引入参数自身:
m t = β 1 m t − 1 + ( 1 − β 1 ) Δ W m_t=\beta_1m_{t-1}+(1-\beta_1)\Delta W mt=β1mt1+(1β1)ΔW

v t = β 2 v t − 1 + ( 1 − β 2 ) Δ W 2 v_t=\beta_2v_{t-1}+(1-\beta_2)\Delta W^2 vt=β2vt1+(1β2)ΔW2

m t ^ = m t 1 − β 1 t \hat{m_t}=\frac{m_t}{1-\beta_1^t} mt^=1β1tmt

v t ^ = v t 1 − β 2 t \hat{v_t}=\frac{v_t}{1-\beta_2^t} vt^=1β2tvt

W t ← W t − 1 − α ( m t ^ v t ^ + ϵ + λ W t − 1 ) W_t\leftarrow W_{t-1}-\alpha\big(\frac{\hat{m_t}}{\sqrt{\hat{v_t}}+\epsilon}+\lambda W_{t-1}\big) WtWt1α(vt^ +ϵmt^+λWt1)

λ \lambda λ 即为权重衰减因子,常见的设置为 0.005/0.01。这一优化策略目前正广泛应用于各大预训练语言模型。

Layer-wise Adaptive Moments optimizer for Batching training (LAMB)

LAMB 优化器是 2019 年出现的一匹新秀,原论文标题后半部分叫做 “Training BERT in 76 Minutes”,足以看出其野心之大。 LAMB 出现的目的是加速预训练进程,这个优化器也成为 NLP 社区为泛机器学习领域做出的一大贡献。在使用 Adam 和 AdamW 等优化器时,一大问题在于 batch size 存在一定的隐式上限,一旦突破这个上限,梯度更新极端的取值会导致自适应学习率调整后极为困难的收敛,从而无法享受增加的 batch size 带来的提速增益。LAMB 优化器的作用便在于使模型在进行大批量数据训练时,能够维持梯度更新的精度:
m t = β 1 m t − 1 + ( 1 − β 1 ) Δ W m_t=\beta_1m_{t-1}+(1-\beta_1)\Delta W mt=β1mt1+(1β1)ΔW

v t = β 2 v t − 1 + ( 1 − β 2 ) Δ W 2 v_t=\beta_2v_{t-1}+(1-\beta_2)\Delta W^2 vt=β2vt1+(1β2)ΔW2

r t = m t v t + ϵ r_t=\frac{m_t}{\sqrt{v_t}+\epsilon} rt=vt +ϵmt

W t ← W t − 1 − α ⋅ ϕ ( ∣ ∣ W t − 1 ∣ ∣ ∣ ∣ r t + λ W t − 1 ∣ ∣ ) ( r t + λ W t − 1 ) W_t\leftarrow W_{t-1}-\alpha\cdot\phi\big(\frac{||W_{t-1}||}{||r_t+\lambda W_{t-1}||}\big)(r_t+\lambda W_{t-1}) WtWt1αϕ(rt+λWt1Wt1)(rt+λWt1)

其中, ϕ \phi ϕ 是一个可选择的映射函数,一种是 ϕ ( z ) = z \phi(z)=z ϕ(z)=z,另一种则为起到归一化作用的 ϕ ( z ) = min ⁡ ( max ⁡ ( z , γ l ) , γ u ) \phi(z)=\min(\max(z, \gamma_l),\gamma_u) ϕ(z)=min(max(z,γl),γu) γ l \gamma_l γl γ u \gamma_u γu 为预先设定的超参数,分别代表参数调整的下界和上界。这一简单的调整所带来的实际效果非常显著。使用 AdamW 时,batch size 超过 512 便会导致模型效果大幅下降,但在 LAMB 下,batch size 可以直接提到 32,000 而不会导致精度损失。

由于在下游微调预训练模型时,通常无需过大的数据集,因而 LAMB 仅在预训练环节使用。遗憾的是,LAMB 在 batch size 512 以下时无法起到显著作用,目前只能作为大体量财团的工具。

附录

以下是 LAMB 优化器的 tensorflow1.x 代码,可作为参考以理解算法,具体的代码出处已无法找寻。

class LAMBOptimizer(tf.train.Optimizer):
    '''
    LAMBOptimizer optimizer.
	
	# Important Note
		- This is NOT an official implementation.
		- LAMB optimizer is changed from arXiv v1 ~ v3.
		- We implement v3 version (which is the latest version on June, 2019.).
		- Our implementation is based on `AdamWeightDecayOptimizer` in BERT (provided by Google).
    # References
		- LAMB optimier: https://github.com/ymcui/LAMB_Optimizer_TF
		- Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962v3
		- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805
    # Parameters
		- There is nothing special, just the same as `AdamWeightDecayOptimizer`.
    '''
    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.01,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 name="LAMBOptimizer"):
        """Constructs a LAMBOptimizer."""
        super(LAMBOptimizer, self).__init__(False, name)

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            m = tf.get_variable(
                name=param_name + "/lamb_m",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            v = tf.get_variable(
                name=param_name + "/lamb_v",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())

            # Standard Adam update.
            next_m = (
                    tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
            next_v = (
                    tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                              tf.square(grad)))

            update = next_m / (tf.sqrt(next_v) + self.epsilon)

            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want ot decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                update += self.weight_decay_rate * param

            ############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ##############

            # Note: Here are two choices for scaling function \phi(z)
            # minmax:   \phi(z) = min(max(z, \gamma_l), \gamma_u)
            # identity: \phi(z) = z
            # The authors does not mention what is \gamma_l and \gamma_u
            # UPDATE: after asking authors, they provide me the code below.
            # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
            #      math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

            r1 = tf.sqrt(tf.reduce_sum(tf.square(param)))
            r2 = tf.sqrt(tf.reduce_sum(tf.square(update)))

            r = tf.where(tf.greater(r1, 0.0),
                         tf.where(tf.greater(r2, 0.0),
                                  r1 / r2,
                                  1.0),
                         1.0)

            eta = self.learning_rate * r

            update_with_lr = eta * update

            next_param = param - update_with_lr

            assignments.extend(
                [param.assign(next_param),
                 m.assign(next_m),
                 v.assign(next_v)])
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name

你可能感兴趣的:(自然语言处理,深度学习,算法)