清华大学出品:罚梯度范数提高深度学习模型泛化性

1 引言

神经网络结构简单,训练样本量不足,则会导致训练出来的模型分类精度不高;神经网络结构复杂,训练样本量过大,则又会导致模型过拟合,所以如何训练神经网络提高模型的泛化性是人工智能领域一个非常核心的问题。最近读到了一篇与该问题相关的文章,论文中作者在训练过程中通过在损失函数中增加正则化项梯度范数的约束从而来提高深度学习模型的泛化性。作者从原理和实验两方面分别对论文中的方法进行了详细地阐述和验证。 L i p s c h i t z \mathrm{Lipschitz} Lipschitz连续是对深度学习进行理论分析中非常重要且常见的数学工具,该论文就是以神经网络损失函数 是 L i p s c h i t z 是\mathrm{Lipschitz} Lipschitz连续为出发点进行数学推导。为了方便读者能够更流畅地欣赏论文作者漂亮的数学证明思路和过程,本文对于论文中没有展开数学证明细节进行了补充。
清华大学出品:罚梯度范数提高深度学习模型泛化性_第1张图片
论文链接:https://arxiv.org/abs/2202.03599

2 L i p s c h i z \mathrm{Lipschiz} Lipschiz连续

给定一个训练数据集 S = { ( x i , y i ) } i = 0 n \mathcal{S}=\{(x_i,y_i)\}_{i=0}^n S={(xi,yi)}i=0n服从分布 D \mathcal{D} D,一个带有参数 θ ∈ Θ \theta \in \Theta θΘ的神经网络 f ( ⋅ ; θ ) f(\cdot;\theta) f(;θ),损失函数为 L S = 1 N ∑ i = 1 N l ( y i , y i , θ ^ ) L_{\mathcal{S}}=\frac{1}{N}\sum\limits_{i=1}^N l(\hat{y_i,y_i ,\theta}) LS=N1i=1Nl(yi,yi,θ^)当需要对损失函数中的梯度范数进行约束时,则有如下损失函数 L ( θ ) = L S + λ ⋅ ∥ ∇ θ L S ( θ ) ∥ p L(\theta)=L_{\mathcal{S}}+\lambda \cdot \|\nabla_\theta L_{\mathcal{S}}(\theta)\|_p L(θ)=LS+λθLS(θ)p其中 ∥ ⋅ ∥ p \|\cdot \|_p p表示 p p p范数, λ ∈ R + \lambda\in \mathbb{R}^{+} λR+为梯度惩罚系数。一般情况下,损失函数引入梯度的正则化项会使得其在优化过程中在局部有更小的 L i p s c h i t z \mathrm{Lipschitz} Lipschitz常数, L i p s c h i t z \mathrm{Lipschitz} Lipschitz常数越小,就意味着损失函数就越平滑,平损失函数平滑区域易于损失函数优化权重参数。进而会使得训练出来的深度学习模型有更好的泛化性。
 深度学习中一个非常重要而且常见的概念就是 L i p s c h i t z \mathrm{Lipschitz} Lipschitz连续。给定一个空间 Ω ⊂ R n \Omega \subset \mathbb{R}^n ΩRn,对于函数 h : Ω → R m h:\Omega \rightarrow \mathbb{R}^m h:ΩRm,如果存在一个常数 K K K,对于 ∀ θ 1 , θ 2 ∈ Ω \forall \theta_1,\theta_2 \in \Omega θ1,θ2Ω满足以下条件则称 L i p s c h i t z \mathrm{Lipschitz} Lipschitz连续 ∥ h ( θ 1 ) − h ( θ 2 ) ∥ 2 ≤ K ⋅ ∥ θ 1 − θ 2 ∥ 2 \|h(\theta_1)-h(\theta_2)\|_2 \le K \cdot \|\theta_1 - \theta_2\|_2 h(θ1)h(θ2)2Kθ1θ22其中 K K K表示的是 L i p s c h i t z \mathrm{Lipschitz} Lipschitz常数。如果对于参数空间 Θ ⊂ Ω \Theta \subset \Omega ΘΩ,如果 Θ \Theta Θ有一个邻域 A \mathcal{A} A,且 h ∣ A h|_{\mathcal{A}} hA L i p s c h i t z \mathrm{Lipschitz} Lipschitz连续,则称 h h h是局部 L i p s c h i t z \mathrm{Lipschitz} Lipschitz连续。直观来看, L i p s c h i t z \mathrm{Lipschitz} Lipschitz常数描述的是输出关于输入变化速率的一个上界。对于一个小的 L i p s c h i t z \mathrm{Lipschitz} Lipschitz参数,在邻域 A \mathcal{A} A中给定任意两个点,它们输出的改变被限制在一个小的范围里。
 根据微分中值定理,给定一个最小值点 θ i \theta_i θi,对于任意点 ∀ θ i ′ ∈ A \forall \theta_i^{\prime}\in \mathcal{A} θiA,则有如下公式成立 ∥ ∣ L ( θ i ′ ) − L ( θ i ) ∥ 2 = ∥ ∇ L ( ζ ) ( θ i ′ − θ i ) ∥ 2 \||L(\theta_i^{\prime})-L(\theta_i)\|_2 = \|\nabla L (\zeta) (\theta_i^{\prime}-\theta_i)\|_2 L(θi)L(θi)2=L(ζ)(θiθi)2其中 ζ = c θ i + ( 1 − c ) θ i ′ , c ∈ [ 0 , 1 ] \zeta=c \theta_i + (1-c)\theta^\prime_i, c \in [0,1] ζ=cθi+(1c)θi,c[0,1],根据 C a u c h y - S c h w a r z \mathrm{Cauchy\text{-}Schwarz} Cauchy-Schwarz不等式可知 ∥ ∣ L ( θ i ′ ) − L ( θ i ) ∥ 2 ≤ ∥ ∇ L ( ζ ) ∥ 2 ∥ ( θ i ′ − θ i ) ∥ 2 \||L(\theta_i^{\prime})-L(\theta_i)\|_2 \le \|\nabla L (\zeta)\|_2 \|(\theta_i^{\prime}-\theta_i)\|_2 L(θi)L(θi)2L(ζ)2(θiθi)2 θ i ′ → θ \theta_i^{\prime}\rightarrow \theta θiθ时,相应的 L i p s c h i z \mathrm{Lipschiz} Lipschiz常数接近 ∥ ∇ L ( θ i ) ∥ 2 \|\nabla L(\theta_i)\|_2 L(θi)2。因此可以通过减小 ∥ ∇ L ( θ i ) ∥ \|\nabla L(\theta_i)\| L(θi)的数值使得模型能够更平滑的收敛。

3 论文方法

对带有梯度范数约束的损失函数求梯度可得
∇ θ L ( θ ) = ∇ θ L S ( θ ) + ∇ θ ( λ ⋅ ∥ ∇ θ L S ( θ ) ∥ p ) \nabla_\theta L(\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+\nabla_\theta(\lambda \cdot \|\nabla_\theta L_{\mathcal{S}}(\theta)\|_p) θL(θ)=θLS(θ)+θ(λθLS(θ)p)在本文中,作者令 p = 2 p=2 p=2,此时则有如下推导过程 ∇ θ ∥ ∇ θ L S ( θ ) ∥ 2 = ∇ θ [ ∇ θ ⊤ L S ( θ ) ⋅ ∇ θ L S ( θ ) ] 1 2 = 1 2 ⋅ ∇ θ 2 L S ( θ ) ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 \begin{aligned}\nabla_\theta \|\nabla_\theta L_\mathcal{S}(\theta)\|_2&=\nabla_\theta[\nabla^{\top}_\theta L_{\mathcal{S}}(\theta)\cdot \nabla_\theta L_\mathcal{S}(\theta)]^{\frac{1}{2}}\\&=\frac{1}{2}\cdot \nabla^2_\theta L_{\mathcal{S}}(\theta)\frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2}\end{aligned} θθLS(θ)2=θ[θLS(θ)θLS(θ)]21=21θ2LS(θ)θLS(θ)2θLS(θ)将该结果带入到梯度范数约束的损失函数中,则有以下公式
∇ θ L ( θ ) = ∇ θ L S ( θ ) + λ ⋅ ∇ θ 2 L S ( θ ) ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 \nabla_\theta L(\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+\lambda \cdot \nabla^2_\theta L_{\mathcal{S}}(\theta) \frac{\nabla_\theta L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2} θL(θ)=θLS(θ)+λθ2LS(θ)θLS(θ)2θLS(θ)可以发现,以上公式中涉及到 H e s s i a n \mathrm{Hessian} Hessian矩阵的计算,在深度学习中,计算参数的 H e s s i a n \mathrm{Hessian} Hessian矩阵会带来高昂的计算成本,所以需要用到一些近似的方法。作者将损失函数进行泰勒展开,其中令 H = ∇ θ 2 L S ( θ ) H=\nabla^2_\theta L_\mathcal{S}(\theta) H=θ2LS(θ),则有 L S ( θ + Δ θ ) = L S ( θ ) + ∇ θ ⊤ L S ( θ ) ⋅ Δ θ + 1 2 Δ θ ⊤ H Δ θ + O ( ∥ Δ θ ∥ 2 2 ) L_\mathcal{S}(\theta+\Delta \theta)=L_\mathcal{S}(\theta)+\nabla^{\top}_{\theta}L_\mathcal{S}(\theta)\cdot \Delta \theta + \frac{1}{2} \Delta \theta^{\top} H \Delta \theta +\mathcal{O}(\|\Delta \theta\|_2^2) LS(θ+Δθ)=LS(θ)+θLS(θ)Δθ+21ΔθHΔθ+O(Δθ22)进而则有 ∇ θ L S ( θ + Δ θ ) = ∇ Δ θ L S ( θ + Δ θ ) = ∇ θ L S ( θ ) + H Δ θ + O ( ∥ Δ θ ∥ 2 2 ) \begin{aligned}\nabla_\theta L_\mathcal{S}(\theta+\Delta \theta)&=\nabla_{\Delta\theta} L_\mathcal{S} (\theta + \Delta\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+ H \Delta \theta + \mathcal{O}(\|\Delta \theta\|^2_2)\end{aligned} θLS(θ+Δθ)=ΔθLS(θ+Δθ)=θLS(θ)+HΔθ+O(Δθ22)其中令 Δ θ = r v \Delta \theta=r v Δθ=rv r r r表示一个小的数值, v v v表示一个向量,带入上式则有 H v = ∇ θ L S ( θ + r v ) − ∇ θ L S ( θ ) r + O ( r ) H v =\frac{\nabla_\theta L_{\mathcal{S}}(\theta + r v)-\nabla_\theta L_{\mathcal{S}}(\theta)}{r}+\mathcal{O}(r) Hv=rθLS(θ+rv)θLS(θ)+O(r)如果令 v = ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ v=\frac{\nabla_{\theta}L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|} v=θLS(θ)θLS(θ),则有 H ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 ≈ ∇ θ L ( θ + r ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 ) − ∇ θ L ( θ ) r H \frac{\nabla_{\theta}L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2}\approx \frac{\nabla_\theta L(\theta + r\frac{\nabla_\theta L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2})-\nabla_\theta L(\theta)}{r} HθLS(θ)2θLS(θ)rθL(θ+rθLS(θ)2θLS(θ))θL(θ)
综上所述,经过整理可得
∇ θ L ( θ ) = ∇ θ L S ( θ ) + λ r ⋅ ( ∇ θ L S ( θ + r ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 ) − ∇ θ L S ( θ ) ) = ( 1 − α ) ∇ θ L S ( θ ) + α ∇ θ L S ( θ + r ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 ) \begin{aligned}\nabla_\theta L(\theta)&=\nabla_\theta L_\mathcal{S} (\theta)+\frac{\lambda}{r}\cdot (\nabla_\theta L_{\mathcal{S}}(\theta + r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})-\nabla_\theta L_\mathcal{S}(\theta))\\&=(1-\alpha)\nabla_\theta L_\mathcal{S} (\theta)+\alpha \nabla_\theta L_\mathcal{S}(\theta+r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})\end{aligned} θL(θ)=θLS(θ)+rλ(θLS(θ+rθLS(θ)2θLS(θ))θLS(θ))=(1α)θLS(θ)+αθLS(θ+rθLS(θ)2θLS(θ))其中 α = λ r \alpha=\frac{\lambda}{r} α=rλ,称 α \alpha α为平衡系数,取值范围为 0 ≤ α ≤ 1 0 \le \alpha \le 1 0α1。作者为了避免在近似计算梯度时,以上公式中的第二项链式法则求梯度需要计算 H e s s i a n \mathrm{Hessian} Hessian矩阵,做了以下的近似则有 ∇ θ L S ( θ + r ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 ) ≈ ∇ θ L S ( θ ) ∣ θ = θ + r ∇ θ L S ( θ ) ∥ ∇ θ L S ( θ ) ∥ 2 \nabla_\theta L_\mathcal{S}(\theta+r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})\approx \nabla_\theta L_\mathcal{S} (\theta)|_{\theta =\theta +r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2}} θLS(θ+rθLS(θ)2θLS(θ))θLS(θ)θ=θ+rθLS(θ)2θLS(θ)以下算法流程图对本论文的训练方法进行汇总
清华大学出品:罚梯度范数提高深度学习模型泛化性_第2张图片

4 实验结果

下表表示的是在 C i f a r 10 \mathrm{Cifar10} Cifar10 C i f a r 100 \mathrm{Cifar100} Cifar100这两个数据集中不同 C N N \mathrm{CNN} CNN网络结构在标准训练, S A M \mathrm{SAM} SAM和本文的梯度约束这三种训练方法之间的测试错误率的比较。可以很直观的发现,本文提出的方法在绝大多数情况下测试错误率都是最低的,这也从侧面验证了经过论文方法的训练可以提高 C N N \mathrm{CNN} CNN模型的泛化性。
清华大学出品:罚梯度范数提高深度学习模型泛化性_第3张图片
论文作者也在当前非常热门的网络结构 V i s i o n   T r a n s f o r m e r \mathrm{Vision \text{ } Transformer} Vision Transformer进行了实验。下表表示的是在 C i f a r 10 \mathrm{Cifar10} Cifar10 C i f a r 100 \mathrm{Cifar100} Cifar100这两个数据集中不同 V i T \mathrm{ViT} ViT网络结构在标准训练, S A M \mathrm{SAM} SAM和本文的梯度约束这三种训练方法之间的测试错误率的比较。同理也可以发现本文提出的方法在所有情况下测试错误率都是最低的,这说明本文的方法也可以提到 V i s i o n   t r a n s f o r m e r \mathrm{Vision \text{ } transformer} Vision transformer模型的泛化性。
清华大学出品:罚梯度范数提高深度学习模型泛化性_第4张图片

你可能感兴趣的:(论文解读,深度学习,神经网络,人工智能)