机器学习笔记之正则化(三)权重衰减角度(偏差方向)

机器学习笔记之正则化——权重衰减角度[偏差方向]

  • 引言
    • 回顾:关于目标函数中的 λ , C \lambda,\mathcal C λ,C
    • 正则化与非正则化之间的偏差
    • 偏差的计算过程

引言

上一节从直观现象的角度观察权重 W \mathcal W W是如何出现权重衰减的,并且介绍了 W \mathcal W W的权重衰减是如何抑制过拟合的发生。本节从偏差方向观察权重衰减

回顾:关于目标函数中的 λ , C \lambda,\mathcal C λ,C

回顾基于拉格朗日乘数法角度的优化问题表示如下:
这里以 L 2 L_2 L2正则化为例。
{ L ( W , λ ) = J ( W ) + λ ( ∣ ∣ W ∣ ∣ 2 − C ) s . t . λ > 0 \begin{cases} \mathcal L(\mathcal W,\lambda) = \mathcal J(\mathcal W) + \lambda (||\mathcal W||_2 - \mathcal C) \\ s.t. \quad \lambda > 0 \end{cases} {L(W,λ)=J(W)+λ(∣∣W2C)s.t.λ>0
目标函数 L ( W , λ ) \mathcal L(\mathcal W,\lambda) L(W,λ)展开,可得到如下形式:
L ( W , λ ) = J ( W ) + λ ∣ ∣ W ∣ ∣ 2 ⏟ 标准正则化 − λ ⋅ C \mathcal L(\mathcal W,\lambda) = \underbrace{\mathcal J(\mathcal W) + \lambda ||\mathcal W||_2}_{标准正则化} - \lambda \cdot \mathcal C L(W,λ)=标准正则化 J(W)+λ∣∣W2λC
其中 C = W T W \mathcal C = \sqrt{\mathcal W^T\mathcal W} C=WTW ,也就是 L 2 L_2 L2正则化范围的半径;而 λ \lambda λ表示拉格朗日参数

关于 λ , C \lambda,\mathcal C λ,C,无论我们修改哪一个参数,最终都会影响正则化在权重空间中的 有效范围

  • λ \lambda λ确定, C \mathcal C C发生变化时,由于 λ ⋅ C \lambda \cdot \mathcal C λC项的目的就是为了在梯度下降每次迭代过程中,找到一个大小相等、方向相反的向量。 λ ⋅ C \lambda \cdot \mathcal C λC就是调节大小的。

    λ \lambda λ确定条件下,仅在当前迭代步骤中, C \mathcal C C也被确定,从而正则化范围被确定。那么这个大小相等、方向相反的向量在权重空间中一定时当前正则化范围中的最优解。这个解大概率会在范围的边缘上,与 J ( W ) \mathcal J(\mathcal W) J(W)在权重空间中的等高线相切

  • C \mathcal C C确定, λ \lambda λ不确定时,也就是说权重范围被确定,我们假设通过不断尝试不同的 λ \lambda λ,使得它与当前迭代步骤的梯度向量大小相等、方向相反。这个值它仅能保证在正则化范围中。但不能确定它是当前迭代步骤的最优解——可能在正则化范围内存在权重点,它距离 J ( W ) \mathcal J(\mathcal W) J(W)的等高线中心更近。

在深度学习框架中,如 PyTorch \text{PyTorch} PyTorch。它的处理方式是第一种,我们手动设置 λ \lambda λ具体值,让神经网络自己调节正则化范围 C \mathcal C C
这里以 Adam \text{Adam} Adam算法为例。这里不深究,仅作一个关于正则化参数的描述。

from torch import optim as optim

Optimizer = optim.Adam(TrainModel.parameters(), lr=LearningRate, weight_decay=WeightDecay)

其中weight_decay参数就是参数 λ \lambda λ的取值。在模型训练过程中,每一次反向传播过程我们会更新 W \mathcal W W的信息,自然也会更新 C = W T W \mathcal C = \sqrt{\mathcal W^T\mathcal W} C=WTW 的信息。

正则化与非正则化之间的偏差

下图表示的是权重空间中某损失函数 J ( W ) \mathcal J(\mathcal W) J(W)等高线,而虚线表示某一梯度方向。
机器学习笔记之正则化(三)权重衰减角度(偏差方向)_第1张图片
上述箭头的指向表示权重最优权重优化的过程。值得注意的是,由于初始化权重 W i n i t \mathcal W_{init} Winit随机的,这意味着虚线/箭头不是唯一确定的。

并且并不是箭头最终指向的点就是最优解。实际上,在迭代过程中,只要 W i n i t \mathcal W_{init} Winit被确定下来,那么该 W i n i t \mathcal W_{init} Winit J ( W ) \mathcal J(\mathcal W) J(W)等高线中心方向上的任意一点,都是对应每次迭代中的最优解

这里依然以 L 1 , L 2 L_1,L_2 L1,L2正则化为例。假设我们的梯度优化方向就是虚线方向,那么观察迭代过程中加入 L 1 , L 2 L_1,L_2 L1,L2正则化后最优解的路径未加入正则化情况下最优解的期望逻辑之间的关系:

  • 这里蓝色点表示添加正则化后的最优解路径;红色点表示未加入正则化后最优解的期望路径。
  • 该图片来自于文章下方的视频链接,下同,侵删。

机器学习笔记之正则化(三)权重衰减角度(偏差方向)_第2张图片
通过观察图像可以发现:无论是 L 1 L_1 L1还是 L 2 L_2 L2正则化,它们的路径均与期望路径之间存在偏差

这个偏差如何计算 ? ? ? 并且这个偏差对于权重衰减有什么关联关系 ? ? ?

偏差的计算过程

这里假设 W ∗ \mathcal W^* W表示损失函数 J ( W ) \mathcal J(\mathcal W) J(W)条件下产生的最优解。那么 J ( W ∗ ) \mathcal J(\mathcal W^*) J(W)就表示该最优解对应的损失函数的最优解
最优解表示‘最小训练误差’时的权重。

J ^ ( W ) \hat J(\mathcal W) J^(W)表示正则化后的损失函数,并假设该函数在 W ^ \hat {\mathcal W} W^中取得最优解 J ^ ( W ^ ) \hat J(\hat {\mathcal W}) J^(W^)。我们需要讨论的是: W ∗ \mathcal W^* W W ^ \hat {\mathcal W} W^之间的偏差是多少。从数学角度观察,我们希望通过公式表达 W ∗ \mathcal W^* W W ^ \hat {\mathcal W} W^之间的函数关系/关联关系

  • 我们通过泰勒公式将损失函数 J ( W ) \mathcal J(\mathcal W) J(W)进行近似表达。这里仅将其扩展至二次
    泰勒公式的标准式进行比较,这里将常数 a 0 = W ∗ a_0 = \mathcal W^* a0=W,真正的变量只有 W \mathcal W W.
    J ( W ) ≈ J ( W ∗ ) + 1 1 ! ∇ W J ( W ∗ ) ⋅ ( W − W ∗ ) + 1 2 ! J ′ ′ ( W ∗ ) ⋅ ( W − W ∗ ) 2 = J ( W ∗ ) + ∇ W J ( W ∗ ) + 1 2 ( W − W ∗ ) T H ( W − W ∗ ) \begin{aligned} \mathcal J(\mathcal W) & \approx \mathcal J(\mathcal W^*) + \frac{1}{1!}\nabla_{\mathcal W} \mathcal J(\mathcal W^*) \cdot (\mathcal W - \mathcal W^*) + \frac{1}{2!} \mathcal J''(\mathcal W^*) \cdot (\mathcal W - \mathcal W^*)^2 \\ & = \mathcal J(\mathcal W^*) + \nabla_{\mathcal W} \mathcal J(\mathcal W^*) + \frac{1}{2} (\mathcal W - \mathcal W^*)^T \mathcal H (\mathcal W - \mathcal W^*) \end{aligned} J(W)J(W)+1!1WJ(W)(WW)+2!1J′′(W)(WW)2=J(W)+WJ(W)+21(WW)TH(WW)
    其中 H \mathcal H H表示 Hession \text{Hession} Hession矩阵,它表示损失函数关于 W ∗ \mathcal W^* W二阶导结果。上式中,由于 J ( W ∗ ) \mathcal J(\mathcal W^*) J(W)是最值点,因而 ∇ W J ( W ∗ ) = 0 \nabla_{\mathcal W} \mathcal J(\mathcal W^*) = 0 WJ(W)=0。最终可将上式写成如下形式:
    J ( W ) ≈ J ( W ∗ ) + 1 2 ( W − W ∗ ) T H ( W − W ∗ ) \mathcal J(\mathcal W) \approx \mathcal J(\mathcal W^*) +\frac{1}{2} (\mathcal W - \mathcal W^*)^T \mathcal H (\mathcal W - \mathcal W^*) J(W)J(W)+21(WW)TH(WW)

  • 对上述损失函数 J ( W ) \mathcal J(\mathcal W) J(W)关于变量 W \mathcal W W求解梯度:

    • 由于 W ∗ \mathcal W^* W已知,那么 J ( W ∗ ) \mathcal J(\mathcal W^*) J(W)也是常数,其梯度为 0 0 0.
    • 这里用到矩阵求导公式.
      ∂ ∂ W [ ( W − W ∗ ) T H ( W − W ∗ ) ] = ∂ ∂ ( W − W ∗ ) [ ( W − W ∗ ) T H ( W − W ∗ ) ] ⋅ ∂ ( W − W ∗ ) ∂ W = 2 ⋅ H ( W − W ∗ ) ⋅ ( 1 − 0 ) = 2 ⋅ H ( W − W ∗ ) ∇ W J ( W ) = ∇ W J ( W ∗ ) ⏟ = 0 + ∇ W [ 1 2 ( W − W ∗ ) T H ( W − W ∗ ) ] = 1 2 ⋅ 2 ⋅ H ( W − W ∗ ) = H ( W − W ∗ ) \begin{aligned} \frac{\partial}{\partial \mathcal W} \left[(\mathcal W - \mathcal W^*)^T \mathcal H (\mathcal W - \mathcal W^*)\right] & = \frac{\partial}{\partial (\mathcal W - \mathcal W^*)}\left[(\mathcal W - \mathcal W^*)^T \mathcal H (\mathcal W - \mathcal W^*)\right] \cdot \frac{\partial (\mathcal W - \mathcal W^*)}{\partial \mathcal W}\\ & = 2 \cdot \mathcal H(\mathcal W - \mathcal W^*) \cdot (1 - 0) \\ & = 2 \cdot \mathcal H(\mathcal W - \mathcal W^*)\\ \nabla_{\mathcal W} \mathcal J(\mathcal W) & = \underbrace{\nabla_{\mathcal W} \mathcal J(\mathcal W^*)}_{=0} + \nabla_{\mathcal W} \left[\frac{1}{2} (\mathcal W - \mathcal W^*)^T \mathcal H (\mathcal W - \mathcal W^*)\right] \\ & = \frac{1}{2} \cdot 2 \cdot \mathcal H(\mathcal W - \mathcal W^*) \\ & = \mathcal H(\mathcal W - \mathcal W^*) \end{aligned} W[(WW)TH(WW)]WJ(W)=(WW)[(WW)TH(WW)]W(WW)=2H(WW)(10)=2H(WW)==0 WJ(W)+W[21(WW)TH(WW)]=212H(WW)=H(WW)
  • 上述结果是未使用正则化条件下关于 W \mathcal W W梯度结果。这里依然以 L 2 L_2 L2正则化为例,观察正则化后的损失函数 J ^ ( W ) \hat {\mathcal J}(\mathcal W) J^(W)的结果以及梯度结果表示为如下形式:
    { J ^ ( W ) = J ( W ∗ ) + 1 2 ( W − W ∗ ) T H ( W − W ∗ ) + α 2 W T W ⏟ 正则化项 ∇ W J ^ ( W ) = H ( W − W ∗ ) + α 2 ⋅ 2 ⋅ W = H ( W − W ∗ ) + α ⋅ W \begin{cases} \begin{aligned} \hat {\mathcal J}(\mathcal W) = \mathcal J(\mathcal W^*) + \frac{1}{2} (\mathcal W - \mathcal W^*)^T \mathcal H(\mathcal W - \mathcal W^*) + \underbrace{\frac{\alpha}{2} \mathcal W^T\mathcal W}_{正则化项} \end{aligned} \\ \begin{aligned} \nabla_{\mathcal W} \hat {\mathcal J}(\mathcal W) & = \mathcal H(\mathcal W - \mathcal W^*) + \frac{\alpha}{2} \cdot 2 \cdot \mathcal W \\ & = \mathcal H(\mathcal W - \mathcal W^*) + \alpha \cdot \mathcal W \end{aligned} \end{cases} J^(W)=J(W)+21(WW)TH(WW)+正则化项 2αWTWWJ^(W)=H(WW)+2α2W=H(WW)+αW

  • 由于 W ^ \hat {\mathcal W} W^使损失函数 J ^ ( W ) \hat {\mathcal J}(\mathcal W) J^(W)取得最小值的解。这意味着 ∇ W J ^ ( W ) ∣ W = W ^ = 0 \nabla_{\mathcal W} \hat {\mathcal J}(\mathcal W) |_{\mathcal W = \hat {\mathcal W}} = 0 WJ^(W)W=W^=0。则有:
    这里 I \mathcal I I表示单位矩阵。
    H ( W ^ − W ∗ ) + α ⋅ W ^ = 0 ⇒ ( H + α ⋅ I ) W ^ = H ⋅ W ∗ ⇒ W ^ = ( H + α ⋅ I ) − 1 H ⋅ W ∗ \begin{aligned} & \quad \mathcal H(\hat {\mathcal W} - \mathcal W^*) + \alpha \cdot \hat {\mathcal W} = 0 \\ & \Rightarrow(\mathcal H + \alpha \cdot \mathcal I) \hat {\mathcal W} = \mathcal H \cdot \mathcal W^* \\ & \Rightarrow \hat {\mathcal W} = (\mathcal H + \alpha \cdot \mathcal I)^{-1} \mathcal H \cdot \mathcal W^* \end{aligned} H(W^W)+αW^=0(H+αI)W^=HWW^=(H+αI)1HW

  • 关于 Hession \text{Hession} Hession矩阵的性质,如果 J ( W ) \mathcal J(\mathcal W) J(W)权重空间连续可导,那么 H \mathcal H H是一个对阵矩阵。对 H \mathcal H H进行特征值分解 H = Q Λ Q T \mathcal H = \mathcal Q \Lambda \mathcal Q^T H=QΛQT,并将该结果代入到上式中:

    • 其中 Q \mathcal Q Q是一个正交矩阵; Λ \Lambda Λ是对角矩阵,其对角线上元素是 H \mathcal H H的特征值。
    • 关于 Q \mathcal Q Q的性质,这里用到的有: Q T = Q − 1 ⇒ Q T Q = Q Q T = 1 \mathcal Q^T = \mathcal Q^{-1} \Rightarrow \mathcal Q^T\mathcal Q = \mathcal Q \mathcal Q^T = 1 QT=Q1QTQ=QQT=1
    • α ⋅ I \alpha \cdot \mathcal I αI同样也是一个对称矩阵,同样也可以进行特征值分解。该部分在线性回归——岭回归中存在相似的步骤。
    • 上述公式对应《深度学习(花书)》P143 7.1 参数范围惩罚 7.13
      W ^ = ( Q Λ Q T + α ⋅ I ) − 1 Q Λ Q T ⋅ W ∗ = [ Q Λ Q T + Q ( α ⋅ I ) Q T ] − 1 ⋅ W ∗ = [ Q ( Λ + α ⋅ I ) Q T ] − 1 Q Λ Q T ⋅ W ∗ = ( Q T ) − 1 ⏟ = Q ( Λ + α ⋅ I ) − 1 Q − 1 Q ⏟ = I Λ Q T ⋅ W ∗ = Q ( Λ + α ⋅ I ) − 1 Λ Q T ⋅ W ∗ \begin{aligned} \hat {\mathcal W} & = (\mathcal Q \Lambda \mathcal Q^T + \alpha \cdot \mathcal I)^{-1} \mathcal Q \Lambda\mathcal Q^T \cdot \mathcal W^* \\ & = \left[\mathcal Q \Lambda \mathcal Q^T + \mathcal Q(\alpha \cdot \mathcal I) \mathcal Q^T\right]^{-1} \cdot \mathcal W^* \\ & = [\mathcal Q (\Lambda + \alpha \cdot \mathcal I) \mathcal Q^T]^{-1} \mathcal Q \Lambda \mathcal Q^T \cdot \mathcal W^* \\ & = \underbrace{(\mathcal Q^T)^{-1}}_{=\mathcal Q}(\Lambda + \alpha \cdot \mathcal I)^{-1} \underbrace{\mathcal Q^{-1} \mathcal Q}_{=\mathcal I} \Lambda \mathcal Q^T \cdot \mathcal W^* \\ & = \mathcal Q(\Lambda + \alpha \cdot \mathcal I)^{-1} \Lambda \mathcal Q^T \cdot \mathcal W^* \end{aligned} W^=(QΛQT+αI)1QΛQTW=[QΛQT+Q(αI)QT]1W=[Q(Λ+αI)QT]1QΛQTW==Q (QT)1(Λ+αI)1=I Q1QΛQTW=Q(Λ+αI)1ΛQTW
  • 观察上式中的 Λ + α ⋅ I \Lambda + \alpha \cdot \mathcal I Λ+αI,由于 Λ \Lambda Λ α ⋅ I \alpha \cdot \mathcal I αI都是对角阵,那么它们相加依然是对角阵。对应的逆矩阵 ( Λ + α ⋅ I ) − 1 (\Lambda + \alpha \cdot \mathcal I)^{-1} (Λ+αI)1就是主对角阵各元素取倒数的结果。由于后面乘了一个 Λ \Lambda Λ,那么取倒数后的分子就是 Λ \Lambda Λ中的各元素。至此, W ^ \hat {\mathcal W} W^ W ∗ \mathcal W^* W之间满足如下关联关系:

    • 这里我们不关心正交矩阵 Q \mathcal Q Q,它仅是一组正交基,无论在哪一组正交基下,不影响 W ∗ \mathcal W^* W W ^ \hat {\mathcal W} W^之间的关系.
    • 其中 W ^ i \hat {\mathcal W}_i W^i表示 W ^ \hat {\mathcal W} W^的第 i i i个分量;同理, W i ∗ \mathcal W_i^* Wi表示 W ∗ \mathcal W^* W的第 i i i个分量。 λ i \lambda_i λi表示对角矩阵 Λ \Lambda Λ i i i行第 i i i列的元素。
      W ^ i = λ i λ i + α ⋅ W i ∗ \hat {\mathcal W}_i = \frac{\lambda_i}{\lambda_i + \alpha} \cdot \mathcal W_i^* W^i=λi+αλiWi
  • 观察上式,显然 W ^ \hat {\mathcal W} W^ W ∗ \mathcal W^* W之间取决于 α \alpha α的结果:
    需要注意的是,在正则化——权重衰减(直观印象)中介绍过, α \alpha α不是的 λ \lambda λ,它内部包含了 W T W − C ⇒ W T W − C 2 \sqrt{\mathcal W^T\mathcal W} - \mathcal C \Rightarrow \mathcal W^T\mathcal W - \mathcal C^2 WTW CWTWC2的多余信息.因此 α \alpha α的取值是不确定的。

    • α = 0 \alpha = 0 α=0时, W T W \mathcal W^T\mathcal W WTW对应的系数为 0 ⇒ W ^ = W ∗ 0 \Rightarrow \hat {\mathcal W} = \mathcal W^* 0W^=W
    • α > 0 \alpha > 0 α>0时, W ^ < W ∗ \hat {\mathcal W} < \mathcal W^* W^<W
    • α < 0 \alpha < 0 α<0时, W ^ > W ∗ \hat {\mathcal W} > \mathcal W^* W^>W

可以看出,在 L 2 L_2 L2正则化中,本质上时通过 α \alpha α W ∗ \mathcal W^* W进行缩放得到的正则化权重结果,这也是权重衰减的本质。

相关参考:
《深度学习(花书)》7.1 参数范数惩罚
“L1和L2正则化”直观理解(之二),为什么又叫权重衰减?到底哪里衰减了?

你可能感兴趣的:(机器学习,算法八股查漏补缺,深度学习,机器学习,人工智能,深度学习,正则化,权重衰减)