SAM解析:Sharpness-Aware Minimization for Efficiently Improving Generalization

论文:Sharpness-Aware Minimization for Efficiently Improving Generalization( ICLR 2021)

一、理论

综合了另一篇论文:ASAM: Adaptive Sharpness-Aware Minimization
for Scale-Invariant Learning of Deep Neural Networks 对理论部分这边的解释,同时这篇论文自己也对SAM做出了改进。

1.要解决什么问题?

(1)目标函数

L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)是总体损失值,而训练集 L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}) LS(w)是样本损失值,即训练集S的损失值(S i.i.d. from distribution D \mathscr{D} D),我们使用 L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}) LS(w)来估计 L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)

根据PAC-Bayesian generalization bound定理的非正式形式(严格的定理以及其证明在论文的附录 A.1):

对于任何 ρ > 0 \rho>0 ρ>0 ρ \rho ρ为领域大小),从分布来看,生成的训练集大概率满足:
L D ( w ) ≤ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) L_{\mathscr{D}}(\boldsymbol{w}) \leq \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) LD(w)ϵ2ρmaxLS(w+ϵ)+h(w22/ρ2) 其中, h : R + → R + h: \mathbb{R}_{+} \rightarrow \mathbb{R}_{+} h:R+R+是单调递增函数。附录 A.1结尾给出了具体这个 h h h函数是什么。
L D ( w ) ≤ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + + k log ⁡ ( 1 + ∥ w ∥ 2 2 ρ 2 ( 1 + log ⁡ ( n ) k ) 2 ) + 4 log ⁡ n δ + 8 log ⁡ ( 6 n + 3 k ) n − 1 \begin{aligned} L_{\mathscr{D}}(\boldsymbol{w}) & \leq \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+\\ &+\sqrt{\frac{k \log \left(1+\frac{\|\boldsymbol{w}\|_{2}^{2}}{\rho^{2}}\left(1+\sqrt{\frac{\log (n)}{k}}\right)^{2}\right)+4 \log \frac{n}{\delta}+8 \log (6 n+3 k)}{n-1}} \end{aligned} LD(w)ϵ2ρmaxLS(w+ϵ)++n1klog(1+ρ2w22(1+klog(n) )2)+4logδn+8log(6n+3k)
那么也就是说,为了得到损失函数 L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)的极小解,现在要求:不等式右边式子,即 max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) maxϵ2ρLS(w+ϵ)+h(w22/ρ2)的最小值,也就是一个极小-极大化minimax问题。

(2)引出锐度sharpness概念

这里把不等式的右边展开为:
[ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) − L S ( w ) ] + L S ( w ) + h ( ∥ w ∥ 2 2 / ρ 2 ) \left[\max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})-L_{\mathcal{S}}(\boldsymbol{w})\right]+L_{\mathcal{S}}(\boldsymbol{w})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) [ϵ2ρmaxLS(w+ϵ)LS(w)]+LS(w)+h(w22/ρ2)
这个中括号里的就是 L S L_{\mathcal{S}} LS锐度sharpness,可以衡量从 w \boldsymbol{w} w移动到相临近的参数值,training loss增加的速度。这个式子就变成sharpness + training loss+ w w w的正则项了。

(3)进一步对目标函数进行变换

考虑到具体的函数 h h h受证明细节的影响严重,把 h ( ∥ w ∥ 2 2 / ρ 2 ) h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) h(w22/ρ2)替换成 λ ∥ w ∥ 2 2 \lambda\|w\|_{2}^{2} λw22,这样就产生了一个标准的L2正则项了, λ ∥ w ∥ 2 2 \lambda\|w\|_{2}^{2} λw22 w w w的正则项, λ \lambda λ即为weight decay。

综合以上,现在的目标函数可以变成(公式0):
min ⁡ w L S S A M ( w ) + λ ∥ w ∥ 2 2  where  L S S A M ( w ) ≜ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) \min _{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w})+\lambda\|\boldsymbol{w}\|_{2}^{2} \quad \text { where } \quad L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \triangleq \max _{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} L_{S}(\boldsymbol{w}+\boldsymbol{\epsilon}) wminLSSAM(w)+λw22 where LSSAM(w)ϵpρmaxLS(w+ϵ)把它叫做Sharpness-Aware Minimization (SAM) problem。这里, ρ ≥ 0 \rho \geq 0 ρ0是一个超参数,以及 p ∈ [ 1 , ∞ ] p \in[1, \infty] p[1,],论文说 p = 2 p =2 p=2通常是最优的(附录C.5)。

2.怎么解决?

为了最小化 L S S A M ( w ) L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) LSSAM(w),论文提出可以通过对inner maximization求微分来得出 ∇ w L S S A M ( w ) \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) wLSSAM(w)的近似值。

可是这个the inner maximization是什么东西呢?

the inner problem,指的是在固定损失函数后,如何在训练集上更新模型超参数使得损失函数值最小,得到当前最优的模型参数的问题。对应的数学表达式就是[超参数] = arg min[损失函数]。

inner maximization其实就是上面的使得损失函数值最小改成最大啦,用数学语言表示就是(公式1):
ϵ ∗ ( w ) ≜ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) \boldsymbol{\epsilon}^{*}(\boldsymbol{w}) \triangleq \underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) ϵ(w)ϵpρargmaxLS(w+ϵ)意思是:在 w w w ϵ \epsilon ϵ邻域 [ a − ϵ , a + ϵ ] [\mathrm{a}-\epsilon, \mathrm{a}+\epsilon] [aϵ,a+ϵ]中,使得目标函数 L S ( w + ϵ ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) LS(w+ϵ)最大的参数 ϵ \epsilon ϵ

然后,y用在 x = w + ϵ x=w+\epsilon x=w+ϵ处, ϵ → 0 \epsilon \to 0 ϵ0一阶泰勒展开式去逼近一个the inner maximization的目标函数:

a. x = x 0 x=x_{0} x=x0处的一阶泰勒展开式:
f ( x ) = f ( x 0 ) + f ′ ( x 0 ) ( x − x 0 ) + o ( x − x 0 ) f(x)=f\left(x_{0}\right)+f^{\prime}\left(x_{0}\right)\left(x-x_{0}\right)+o\left(x-x_{0}\right) f(x)=f(x0)+f(x0)(xx0)+o(xx0)

b. x = w + ϵ x=w+\epsilon x=w+ϵ处, ϵ → 0 \epsilon \to 0 ϵ0的一阶泰勒展开式:
L S ( w + ϵ ) = L S ( w ) + ϵ T ∇ w L S ( w ) + o ( ϵ ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})=L_{\mathcal{S}}(\boldsymbol{w})+ \epsilon^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})+o\left(\epsilon\right) LS(w+ϵ)=LS(w)+ϵTwLS(w)+o(ϵ) L S ( w + ϵ ) ≈ L S ( w ) + ϵ T ∇ w L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \approx L_{\mathcal{S}}(\boldsymbol{w})+ \epsilon^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) LS(w+ϵ)LS(w)+ϵTwLS(w)

c. 又因为 arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w ) = 0 \underset{\|\epsilon\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}) =0 ϵpρargmaxLS(w)=0
把b、c公式代入公式1,得到:
ϵ ∗ ( w ) ≜ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) ≈ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w ) + ϵ T ∇ w L S ( w ) = arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ ϵ T ∇ w L S ( w ) \begin{aligned} \boldsymbol{\epsilon}^{*}(\boldsymbol{w}) &\triangleq \underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \\ & \approx \underset{\|\epsilon\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w})+\boldsymbol{\epsilon}^{T} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) \\ &=\underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } \boldsymbol{\epsilon}^{T} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) \end{aligned} ϵ(w)ϵpρargmaxLS(w+ϵ)ϵpρargmaxLS(w)+ϵTwLS(w)=ϵpρargmaxϵTwLS(w)

现在,用 ϵ ^ ( w ) \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) ϵ^(w)表示 ϵ ( w ) \boldsymbol{\epsilon}(\boldsymbol{w}) ϵ(w)的近似值, ϵ ^ ( w ) \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) ϵ^(w)可以用经典对偶范数问题(dual
norm problem)的解法来得出:
ϵ ^ ( w ) = ρ sign ⁡ ( ∇ w L S ( w ) ) ∣ ∇ w L S ( w ) ∣ q − 1 / ( ∥ ∇ w L S ( w ) ∥ q q ) 1 / p \hat{\boldsymbol{\epsilon}}(\boldsymbol{w})=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right)\left|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|^{q-1} /\left(\left\|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right\|_{q}^{q}\right)^{1 / p} ϵ^(w)=ρsign(wLS(w))wLS(w)q1/(wLS(w)qq)1/p

其中, 1 / p + 1 / q = 1 1 / p+1 / q=1 1/p+1/q=1

之前说过论文提到 p = 2 p =2 p=2通常效果最优,这里就把 p = 2 p =2 p=2代入到上式中, 1 / p = 1 / 2 1/p = 1/2 1/p=1/2, q − 1 = 1 q-1=1 q1=1,则得到公式2
ϵ ^ ( w ) = ρ sign ⁡ ( ∇ w L S ( w ) ) ∣ ∇ w L S ( w ) ∣ ∥ ∇ w L S ( w ) ∥ 2 = ρ ∇ L S ( w ) ∥ ∇ L S ( w ) ∥ 2 \begin{aligned} \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) &=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right) \frac{|\nabla_{\boldsymbol{w}} L_{S}\left(\mathbf{\boldsymbol{w}}\right)|}{\left\|\nabla_{\boldsymbol{w}} L_{S}\left(\mathbf{\boldsymbol{w}}\right)\right\|_{2}} \\ & = \rho \frac{\nabla L_{S}\left(\mathbf{w}\right)}{\left\|\nabla L_{S}\left(\mathbf{w}\right)\right\|_{2}} \end{aligned} ϵ^(w)=ρsign(wLS(w))wLS(w)2wLS(w)=ρLS(w)2LS(w)
把上式代入到公式0,得到:
∇ w L S S A M ( w ) ≈ ∇ w L S ( w + ϵ ^ ( w ) ) = d ( w + ϵ ^ ( w ) ) d w ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) = ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) + d ϵ ^ ( w ) d w ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) \begin{aligned} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) & \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})) \\ & =\left.\frac{d(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(w)} \\ &=\left.\nabla_{w} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})}+\left.\frac{d \hat{\epsilon}(\boldsymbol{w})}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})} \end{aligned} wLSSAM(w)wLS(w+ϵ^(w))=dwd(w+ϵ^(w))wLS(w)w+ϵ^(w)=wLS(w)w+ϵ^(w)+dwdϵ^(w)wLS(w)w+ϵ^(w)
上面第二行用到了复合微分:
d f ( g ( x ) ) d x = d g ( x ) d x d f ( x ) ∣ g ( x ) \begin{aligned} &\frac {d f(g(x))}{dx} =\left.\frac{d g(x)}{d x} d f(x)\right|_{g(x)} \end{aligned} dxdf(g(x))=dxdg(x)df(x)g(x)
公式3
∇ w L S S A M ( w ) ≈ ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) \left.\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(w)\right|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})} wLSSAM(w)wLS(w)w+ϵ^(w)
如Section 3中的结果所示,如果没有二阶项,这个近似值会产生一个有效的算法。在论文附录C.4中具体研究了这种影响,在开始的实验中,加入二阶项会令人惊讶地降低性能,进一步研究这些条款的影响应该是未来工作的优先事项。

我们通过将标准的数值优化器(如随机梯度下降SGD)应用于SAM目标函数 L S S A M ( w ) L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) LSSAM(w),其中使用公式3计算目标函数梯度,从而获得最终的SAM算法。

论文的Algorithm 1给出了SAM算法的伪代码,它使用SGD作为基本优化器。右边的Figure 2示意性地说明了SAM参数的单次迭代更新。

SAM解析:Sharpness-Aware Minimization for Efficiently Improving Generalization_第1张图片

你可能感兴趣的:(一起来读论文,概率论,深度学习,机器学习)