论文:Sharpness-Aware Minimization for Efficiently Improving Generalization( ICLR 2021)
综合了另一篇论文:ASAM: Adaptive Sharpness-Aware Minimization
for Scale-Invariant Learning of Deep Neural Networks 对理论部分这边的解释,同时这篇论文自己也对SAM做出了改进。
L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)是总体损失值,而训练集 L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}) LS(w)是样本损失值,即训练集S的损失值(S i.i.d. from distribution D \mathscr{D} D),我们使用 L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}) LS(w)来估计 L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)。
根据PAC-Bayesian generalization bound定理的非正式形式(严格的定理以及其证明在论文的附录 A.1):
对于任何 ρ > 0 \rho>0 ρ>0( ρ \rho ρ为领域大小),从分布来看,生成的训练集大概率满足:
L D ( w ) ≤ max ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) L_{\mathscr{D}}(\boldsymbol{w}) \leq \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) LD(w)≤∥ϵ∥2≤ρmaxLS(w+ϵ)+h(∥w∥22/ρ2) 其中, h : R + → R + h: \mathbb{R}_{+} \rightarrow \mathbb{R}_{+} h:R+→R+是单调递增函数。附录 A.1结尾给出了具体这个 h h h函数是什么。
L D ( w ) ≤ max ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + + k log ( 1 + ∥ w ∥ 2 2 ρ 2 ( 1 + log ( n ) k ) 2 ) + 4 log n δ + 8 log ( 6 n + 3 k ) n − 1 \begin{aligned} L_{\mathscr{D}}(\boldsymbol{w}) & \leq \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+\\ &+\sqrt{\frac{k \log \left(1+\frac{\|\boldsymbol{w}\|_{2}^{2}}{\rho^{2}}\left(1+\sqrt{\frac{\log (n)}{k}}\right)^{2}\right)+4 \log \frac{n}{\delta}+8 \log (6 n+3 k)}{n-1}} \end{aligned} LD(w)≤∥ϵ∥2≤ρmaxLS(w+ϵ)++n−1klog(1+ρ2∥w∥22(1+klog(n))2)+4logδn+8log(6n+3k)
那么也就是说,为了得到损失函数 L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)的极小解,现在要求:不等式右边式子,即 max ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) max∥ϵ∥2≤ρLS(w+ϵ)+h(∥w∥22/ρ2)的最小值,也就是一个极小-极大化minimax问题。
这里把不等式的右边展开为:
[ max ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) − L S ( w ) ] + L S ( w ) + h ( ∥ w ∥ 2 2 / ρ 2 ) \left[\max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})-L_{\mathcal{S}}(\boldsymbol{w})\right]+L_{\mathcal{S}}(\boldsymbol{w})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) [∥ϵ∥2≤ρmaxLS(w+ϵ)−LS(w)]+LS(w)+h(∥w∥22/ρ2)
这个中括号里的就是 L S L_{\mathcal{S}} LS的锐度sharpness,可以衡量从 w \boldsymbol{w} w移动到相临近的参数值,training loss增加的速度。这个式子就变成sharpness + training loss+ w w w的正则项了。
考虑到具体的函数 h h h受证明细节的影响严重,把 h ( ∥ w ∥ 2 2 / ρ 2 ) h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) h(∥w∥22/ρ2)替换成 λ ∥ w ∥ 2 2 \lambda\|w\|_{2}^{2} λ∥w∥22,这样就产生了一个标准的L2正则项了, λ ∥ w ∥ 2 2 \lambda\|w\|_{2}^{2} λ∥w∥22为 w w w的正则项, λ \lambda λ即为weight decay。
综合以上,现在的目标函数可以变成(公式0):
min w L S S A M ( w ) + λ ∥ w ∥ 2 2 where L S S A M ( w ) ≜ max ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) \min _{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w})+\lambda\|\boldsymbol{w}\|_{2}^{2} \quad \text { where } \quad L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \triangleq \max _{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} L_{S}(\boldsymbol{w}+\boldsymbol{\epsilon}) wminLSSAM(w)+λ∥w∥22 where LSSAM(w)≜∥ϵ∥p≤ρmaxLS(w+ϵ)把它叫做Sharpness-Aware Minimization (SAM) problem。这里, ρ ≥ 0 \rho \geq 0 ρ≥0是一个超参数,以及 p ∈ [ 1 , ∞ ] p \in[1, \infty] p∈[1,∞],论文说 p = 2 p =2 p=2通常是最优的(附录C.5)。
为了最小化 L S S A M ( w ) L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) LSSAM(w),论文提出可以通过对inner maximization求微分来得出 ∇ w L S S A M ( w ) \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) ∇wLSSAM(w)的近似值。
可是这个the inner maximization是什么东西呢?
the inner problem,指的是在固定损失函数后,如何在训练集上更新模型超参数使得损失函数值最小,得到当前最优的模型参数的问题。对应的数学表达式就是[超参数] = arg min[损失函数]。
inner maximization其实就是上面的使得损失函数值最小改成最大啦,用数学语言表示就是(公式1):
ϵ ∗ ( w ) ≜ arg max ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) \boldsymbol{\epsilon}^{*}(\boldsymbol{w}) \triangleq \underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) ϵ∗(w)≜∥ϵ∥p≤ρargmaxLS(w+ϵ)意思是:在 w w w的 ϵ \epsilon ϵ邻域 [ a − ϵ , a + ϵ ] [\mathrm{a}-\epsilon, \mathrm{a}+\epsilon] [a−ϵ,a+ϵ]中,使得目标函数 L S ( w + ϵ ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) LS(w+ϵ)最大的参数 ϵ \epsilon ϵ。
然后,y用在 x = w + ϵ x=w+\epsilon x=w+ϵ处, ϵ → 0 \epsilon \to 0 ϵ→0的一阶泰勒展开式去逼近一个the inner maximization的目标函数:
a. 在 x = x 0 x=x_{0} x=x0处的一阶泰勒展开式:
f ( x ) = f ( x 0 ) + f ′ ( x 0 ) ( x − x 0 ) + o ( x − x 0 ) f(x)=f\left(x_{0}\right)+f^{\prime}\left(x_{0}\right)\left(x-x_{0}\right)+o\left(x-x_{0}\right) f(x)=f(x0)+f′(x0)(x−x0)+o(x−x0)
b. 在 x = w + ϵ x=w+\epsilon x=w+ϵ处, ϵ → 0 \epsilon \to 0 ϵ→0的一阶泰勒展开式:
L S ( w + ϵ ) = L S ( w ) + ϵ T ∇ w L S ( w ) + o ( ϵ ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})=L_{\mathcal{S}}(\boldsymbol{w})+ \epsilon^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})+o\left(\epsilon\right) LS(w+ϵ)=LS(w)+ϵT∇wLS(w)+o(ϵ) L S ( w + ϵ ) ≈ L S ( w ) + ϵ T ∇ w L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \approx L_{\mathcal{S}}(\boldsymbol{w})+ \epsilon^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) LS(w+ϵ)≈LS(w)+ϵT∇wLS(w)
c. 又因为 arg max ∥ ϵ ∥ p ≤ ρ L S ( w ) = 0 \underset{\|\epsilon\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}) =0 ∥ϵ∥p≤ρargmaxLS(w)=0
把b、c公式代入公式1,得到:
ϵ ∗ ( w ) ≜ arg max ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) ≈ arg max ∥ ϵ ∥ p ≤ ρ L S ( w ) + ϵ T ∇ w L S ( w ) = arg max ∥ ϵ ∥ p ≤ ρ ϵ T ∇ w L S ( w ) \begin{aligned} \boldsymbol{\epsilon}^{*}(\boldsymbol{w}) &\triangleq \underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \\ & \approx \underset{\|\epsilon\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w})+\boldsymbol{\epsilon}^{T} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) \\ &=\underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } \boldsymbol{\epsilon}^{T} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) \end{aligned} ϵ∗(w)≜∥ϵ∥p≤ρargmaxLS(w+ϵ)≈∥ϵ∥p≤ρargmaxLS(w)+ϵT∇wLS(w)=∥ϵ∥p≤ρargmaxϵT∇wLS(w)
现在,用 ϵ ^ ( w ) \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) ϵ^(w)表示 ϵ ( w ) \boldsymbol{\epsilon}(\boldsymbol{w}) ϵ(w)的近似值, ϵ ^ ( w ) \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) ϵ^(w)可以用经典对偶范数问题(dual
norm problem)的解法来得出:
ϵ ^ ( w ) = ρ sign ( ∇ w L S ( w ) ) ∣ ∇ w L S ( w ) ∣ q − 1 / ( ∥ ∇ w L S ( w ) ∥ q q ) 1 / p \hat{\boldsymbol{\epsilon}}(\boldsymbol{w})=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right)\left|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|^{q-1} /\left(\left\|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right\|_{q}^{q}\right)^{1 / p} ϵ^(w)=ρsign(∇wLS(w))∣∇wLS(w)∣q−1/(∥∇wLS(w)∥qq)1/p
其中, 1 / p + 1 / q = 1 1 / p+1 / q=1 1/p+1/q=1。
之前说过论文提到 p = 2 p =2 p=2通常效果最优,这里就把 p = 2 p =2 p=2代入到上式中, 1 / p = 1 / 2 1/p = 1/2 1/p=1/2, q − 1 = 1 q-1=1 q−1=1,则得到公式2:
ϵ ^ ( w ) = ρ sign ( ∇ w L S ( w ) ) ∣ ∇ w L S ( w ) ∣ ∥ ∇ w L S ( w ) ∥ 2 = ρ ∇ L S ( w ) ∥ ∇ L S ( w ) ∥ 2 \begin{aligned} \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) &=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right) \frac{|\nabla_{\boldsymbol{w}} L_{S}\left(\mathbf{\boldsymbol{w}}\right)|}{\left\|\nabla_{\boldsymbol{w}} L_{S}\left(\mathbf{\boldsymbol{w}}\right)\right\|_{2}} \\ & = \rho \frac{\nabla L_{S}\left(\mathbf{w}\right)}{\left\|\nabla L_{S}\left(\mathbf{w}\right)\right\|_{2}} \end{aligned} ϵ^(w)=ρsign(∇wLS(w))∥∇wLS(w)∥2∣∇wLS(w)∣=ρ∥∇LS(w)∥2∇LS(w)
把上式代入到公式0,得到:
∇ w L S S A M ( w ) ≈ ∇ w L S ( w + ϵ ^ ( w ) ) = d ( w + ϵ ^ ( w ) ) d w ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) = ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) + d ϵ ^ ( w ) d w ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) \begin{aligned} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) & \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})) \\ & =\left.\frac{d(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(w)} \\ &=\left.\nabla_{w} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})}+\left.\frac{d \hat{\epsilon}(\boldsymbol{w})}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})} \end{aligned} ∇wLSSAM(w)≈∇wLS(w+ϵ^(w))=dwd(w+ϵ^(w))∇wLS(w)∣∣∣∣w+ϵ^(w)=∇wLS(w)∣w+ϵ^(w)+dwdϵ^(w)∇wLS(w)∣∣∣∣w+ϵ^(w)
上面第二行用到了复合微分:
d f ( g ( x ) ) d x = d g ( x ) d x d f ( x ) ∣ g ( x ) \begin{aligned} &\frac {d f(g(x))}{dx} =\left.\frac{d g(x)}{d x} d f(x)\right|_{g(x)} \end{aligned} dxdf(g(x))=dxdg(x)df(x)∣∣∣∣g(x)
公式3:
∇ w L S S A M ( w ) ≈ ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) \left.\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(w)\right|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})} ∇wLSSAM(w)≈∇wLS(w)∣∣w+ϵ^(w)
如Section 3中的结果所示,如果没有二阶项,这个近似值会产生一个有效的算法。在论文附录C.4中具体研究了这种影响,在开始的实验中,加入二阶项会令人惊讶地降低性能,进一步研究这些条款的影响应该是未来工作的优先事项。
我们通过将标准的数值优化器(如随机梯度下降SGD)应用于SAM目标函数 L S S A M ( w ) L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) LSSAM(w),其中使用公式3计算目标函数梯度,从而获得最终的SAM算法。
论文的Algorithm 1给出了SAM算法的伪代码,它使用SGD作为基本优化器。右边的Figure 2示意性地说明了SAM参数的单次迭代更新。