谱范数正则(Spectral Norm Regularization)的理解

近来,DeepMind的一篇论文《LARGE SCALE GAN TRAINING FOR
HIGH FIDELITY NATURAL IMAGE SYNTHESIS》(arXiv:1809.11096v1)[1](通过大规模Gan训练,得到高精度的合成自然图像)引起了广泛的关注。其中,为保证其大批次(batch够大)Gan训练的稳定性,[1]引入了谱范数正则技术(Spectral Norm Regularization)。该技术从每层神经网络的参数矩阵的谱范数角度,引入正则约束,使神经网络对输入扰动具有较好的非敏感性,从而使训练过程更稳定,更容易收敛。
谱范数正则(Spectral Norm Regularization,简称为SNR)最早来自于2017年5月日本国立信息研究所Yoshida的一篇论文[2],他们后续又于2018年2月再再arXiv发了一篇SNR用于Gan的论文[3],以阐明SNR的有效性。因为当SGD(统计梯度下降)的批次(Batch size)一大的时候,其泛化性能却会降低,SNR能有效地解决这一问题。

SNR的讨论是从网络的泛化((Generalizability))开始的。对于Deep Learning而言,泛化是一个重要的性能指标,直觉上它与扰动(Perturbation)的影响有关。我们可以这样理解:局部最小点附近如果是平坦(flatness)的话,那么其泛化的性能将较好,反之,若是不平坦(sharpness)的话,稍微一点变动,将产生较大变化,则其泛化性能就不好。因此,我们可以从网络对抗扰动的性能入手来提升网络的泛化能力。

一、扰动的表示

对应多层神经网络而言,扰动(Perturbation)的来源主要有两个:1)参数的扰动;2)输入的扰动。[2]是从输入扰动的角度来进行讨论的。假设一个前馈网络的第 l l l 层有如下关系:
x l = f l ( W l x l − 1 + b l ) ( 1 ) \mathbf x^l=f^l(W^l\mathbf x^{l-1}+\mathbf b^l)\qquad(1) xl=fl(Wlxl1+bl)(1)
(1)中, x l \mathbf x^l xl 表示第 l l l 层的输出, x l − 1 \mathbf x^{l-1} xl1 表示第 l l l 层的输入, W l , b l W^l,\mathbf b^l Wl,bl 分别表示该层神经网络的参数矩阵和偏置向量, f l ( ⋅ ) f^l(\cdot) fl() 表示网络的非线性激活函数, l = 1 , ⋯   , L l=1,\cdots,L l=1,,L 即整个网络有L层。于是,整个网络的参数集合可用 Θ = { W l , b l } l = 1 L \Theta = \{ W^l,\mathbf b^l\}^L_{l=1} Θ={Wl,bl}l=1L 表示。
对于给定训练集: ( x i , y i ) i = 1 K (\mathbf x_i, \mathbf y_i)^K_{i=1} (xi,yi)i=1K,其中 x i ∈ R n 0 , y i ∈ R n L \mathbf x_i \in \mathbb R^{n_0},\mathbf y_i \in \mathbb R^{n_L} xiRn0,yiRnL,则Loss 函数可以表示为:
L o s s = 1 K ∑ i = 1 K L ( f Θ ( x i ) , y i ) ( 2 ) Loss=\frac{1}{K}\sum^K_{i=1}L(f_{\Theta}(\mathbf x_i),\mathbf y_i)\qquad(2) Loss=K1i=1KL(fΘ(xi),yi)(2)
其中, L ( ⋅ ) L(\cdot) L() 表示我们常用的优化目标函数,如:交叉熵用于分类(Classification)任务、最小平方差 l 2 l_2 l2用于回归(Regression)任务。
所谓输入扰动,就指:输入有一个很小的变化,引起的输出变化:
x → x + ξ f ( x ) → f ( x + ξ ) So we define: P = ∥ f ( x + ξ ) − f ( x ) ∥ ∥ ξ ∥ ( 3 ) \mathbf x\rightarrow \mathbf x+\mathbf \xi \\ f(\mathbf x) \rightarrow f(\mathbf x +\mathbf \xi )\\ \text{So we define:}\\ P=\frac{\Vert f(\mathbf x +\mathbf \xi )-f(\mathbf x)\Vert}{\Vert \mathbf \xi \Vert} \qquad(3) xx+ξf(x)f(x+ξ)So we define:P=ξf(x+ξ)f(x)(3)
我们要考察输入扰动的影响,可通过扰动指数—— P P P,定量分析。对于多层神经网络,其非线性的引入是由于非线性激活函数。对于常见的非线性函数,如:ReLU、maxout、maxpooling等,我们可以将它看作是分段线性函数,因此,对于 x \mathbf x x 的邻域来说,可看成是线性函数,如:ReLu。输入扰动发生在 x \mathbf x x 的邻域中,对于单层神经网络(未经激活函数)有以下关系:
∥ f ( x + ξ ) − f ( x ) ∥ ∥ ξ ∥ = ∥ W Θ , x ( x + ξ ) + b Θ , x − W Θ , x x − b Θ , x ∥ ∥ ξ ∥ = ∥ W Θ , x ξ ∥ ∥ ξ ∥ ≤ σ ( W Θ , x ) ( 4 ) \frac{\Vert f(\mathbf x +\mathbf \xi )-f(\mathbf x)\Vert}{\Vert \mathbf \xi \Vert} = \frac{\Vert W_{\Theta,x}(\mathbf x +\mathbf \xi )+\mathbf b_{\Theta,x}-W_{\Theta,x}\mathbf x -\mathbf b_{\Theta,x}\Vert}{\Vert \mathbf \xi \Vert}\\ =\frac{\Vert W_{\Theta,x} \xi \Vert}{\Vert \mathbf \xi \Vert} \le \sigma(W_{\Theta,x}) \qquad(4) ξf(x+ξ)f(x)=ξWΘ,x(x+ξ)+bΘ,xWΘ,xxbΘ,x=ξWΘ,xξσ(WΘ,x)(4)
其中, σ ( W Θ , x ) \sigma(W_{\Theta,x}) σ(WΘ,x) 是矩阵 W Θ , x W_{\Theta,x} WΘ,x 的谱范数,谱范数的定义为:
 A is a matrix,  A ∈ R m × n σ ( A ) = m a x ξ ∈ R n × 1 , ξ ≠ 0 ∥ A ξ ∥ 2 ∥ ξ ∥ 2 ( 5 ) \text{ A is a matrix, } A \in \mathbb R^{m\times n}\\ \sigma(A) = max_{\xi \in R^{n\times1},\xi \neq0} \frac{\Vert A \xi \Vert_2}{\Vert\xi\Vert_2} \qquad(5)  A is a matrix, ARm×nσ(A)=maxξRn×1,ξ̸=0ξ2Aξ2(5)
所谓谱范数,就是它所对应矩阵 A A A 的最大奇异值(Singular Value)。
若选择网络的激活函数为ReLU,函数的作用相当于一个对角矩阵,其对角元素在输入为正时,等于1;输入为负时,等于0。于是,第 l l l 层的激活函数可表示为对角矩阵: D Θ , x l ∈ R n l × n l D_{\Theta,x}^l \in \mathbb R^{n^l\times n^l} DΘ,xlRnl×nl。由此,多层网络映射可表示为矩阵相乘,于是有:
y = W Θ , x x . W Θ , x = D Θ , x L W L D Θ , x L − 1 W L − 1 ⋯ D Θ , x 1 W 1 ( 6 ) \mathbf y = W_{\Theta,x} \mathbf x \\ . \\ W_{\Theta,x}=D_{\Theta,x}^L W^L D_{\Theta,x}^{L-1} W^{L-1}\cdots D_{\Theta,x}^1 W^1 \qquad(6) y=WΘ,xx.WΘ,x=DΘ,xLWLDΘ,xL1WL1DΘ,x1W1(6)
因此有:
σ ( W Θ , x ) ≤ σ ( D Θ , x L ) σ ( W Θ , x L ) σ ( D Θ , x L − 1 ) σ ( W Θ , x L − 1 ) ⋯ σ ( D Θ , x 1 ) σ ( W Θ , x 1 ) ≤ ∏ l = 1 L σ ( W l ) ( 7 ) \sigma(W_{\Theta,x} )\le \sigma(D_{\Theta,x}^L)\sigma(W_{\Theta,x}^L)\sigma(D_{\Theta,x}^{L-1})\sigma(W_{\Theta,x}^{L-1})\cdots\sigma(D_{\Theta,x}^1)\sigma(W_{\Theta,x}^1)\le \prod_{l=1}^L \sigma(W^l)\qquad(7) σ(WΘ,x)σ(DΘ,xL)σ(WΘ,xL)σ(DΘ,xL1)σ(WΘ,xL1)σ(DΘ,x1)σ(WΘ,x1)l=1Lσ(Wl)(7)
公式(7)给出了整个神经网络的扰动指数的上限,它是各层子网络谱范数的乘积。为限制扰动带来的影响,可将谱范数作为正则项加在传统的Loss中,于是寻优过程变为:
Θ = arg ⁡ min ⁡ Θ ( 1 K ∑ i = 1 K L ( f Θ ( x i ) , y i ) + λ 2 ∑ i = 1 K σ ( W l ) 2 ) ( 8 ) \Theta = \arg\min_{\Theta}\left(\frac 1 K \sum_{i=1}^K L(f_{\Theta}(\mathbf x_i),\mathbf y_i) + \frac {\lambda}{2} \sum_{i=1}^K \sigma(W^l)^2 \right) \qquad(8) Θ=argΘmin(K1i=1KL(fΘ(xi),yi)+2λi=1Kσ(Wl)2)(8)
(8)式通过惩罚各层的谱范数总和,以实现对整个网络的谱范数的限制。

二、谱范数正则项

在通过SGD(统计梯度下降)的方法求最优值时,需要(8)式对 Θ \Theta Θ 求梯度,在实践时,需要求出各层的最大奇异值,这将涉及大量的计算,我们可以用”幂迭代“法来近似它:
u n ← W v n − 1 v n ← W T u n and  σ ( W l ) = ∥ u ∥ 2 ∥ v ∥ 2 ( 9 ) u_{n} \leftarrow W v_{n-1}\\ v_{n}\leftarrow W^T u_n\\ \text{and } \sigma(W^l) = \frac{\Vert u \Vert_2} {\Vert v \Vert_2} \qquad(9) unWvn1vnWTunand σ(Wl)=v2u2(9)
v 0 v_0 v0 可以是一个随机矢量(比如:高斯矢量),通过迭代,可得到谱范数的近似值。(9)式为什么可以求出谱范数呢?[4]给出了一个推导过程,为本文的完整性,我在此重抄了一次。
A = W T W A=W^TW A=WTW 是一个对称阵,形状为 n × n n\times n n×n,并可对角化,令其特征根为: λ 1 , ⋯   , λ n \lambda_1,\cdots,\lambda_n λ1,,λn,它们对应的归一化特征向量为: η 1 , ⋯   , η n \eta_1,\cdots,\eta_n η1,,ηn,它们相互正交,模为1。这些特征向量构成A的列矢量空间的一个基。令:
u ( 0 ) = c 1 η 1 + ⋯ + c n η n A u ( 0 ) = A ( c 1 η 1 + ⋯ + c n η n ) = c 1 λ 1 η 1 + ⋯ + c n λ n η n A A u ( 0 ) = A A ( c 1 η 1 + ⋯ + c n η n ) = c 1 λ 1 2 η 1 + ⋯ + c n λ n 2 η n ⋯ A r u ( 0 ) = A r ( c 1 η 1 + ⋯ + c n η n ) = c 1 λ 1 r η 1 + ⋯ + c n λ n r η n u^{(0)}=c_1\eta_1+\cdots+c_n\eta_n \\ Au^{(0)}=A(c_1\eta_1+\cdots+c_n\eta_n)=c_1\lambda_1\eta_1+\cdots+c_n\lambda_n\eta_n\\ AAu^{(0)}=AA(c_1\eta_1+\cdots+c_n\eta_n)=c_1\lambda_1^2\eta_1+\cdots+c_n\lambda_n^2\eta_n\\ \cdots \\ A^ru^{(0)}=A^r(c_1\eta_1+\cdots+c_n\eta_n)=c_1\lambda_1^r\eta_1+\cdots+c_n\lambda_n^r\eta_n u(0)=c1η1++cnηnAu(0)=A(c1η1++cnηn)=c1λ1η1++cnλnηnAAu(0)=AA(c1η1++cnηn)=c1λ12η1++cnλn2ηnAru(0)=Ar(c1η1++cnηn)=c1λ1rη1++cnλnrηn
λ 1 \lambda_1 λ1 为最大者,有:
A r u ( 0 ) λ 1 r = c 1 η 1 + ⋯ + c n ( λ n λ 1 ) r η n ∵ λ k λ 1 < 1 , ∴ lim ⁡ r → ∞ A r u ( 0 ) λ 1 r = c 1 η 1 \frac{A^ru^{(0)}}{\lambda_1^r}=c_1\eta_1+\cdots+c_n(\frac {\lambda_n} {\lambda_1})^r\eta_n \\ \because \frac{\lambda_k}{\lambda_1}\lt 1,\therefore \lim_{r\rightarrow\infty}\frac{A^ru^{(0)}}{\lambda_1^r}=c_1\eta_1 λ1rAru(0)=c1η1++cn(λ1λn)rηnλ1λk<1,rlimλ1rAru(0)=c1η1

u = A r u ( 0 ) ∥ A r u ( 0 ) ∥ 2 ,  so . A u = A A r u ( 0 ) ∥ A r u ( 0 ) ∥ 2 ≈ A r + 1 c 1 η 1 ∥ A r u ( 0 ) ∥ 2 = λ 1 η 1 u = \frac{A^ru^{(0)}}{\Vert A^ru^{(0)}\Vert_2},\text{ so}\\ . \\ Au=A\frac{A^ru^{(0)}}{\Vert A^ru^{(0)}\Vert_2}\approx \frac{A^{r+1}c_1\eta_1}{\Vert A^ru^{(0)}\Vert_2}=\lambda_1\eta_1 u=Aru(0)2Aru(0), so.Au=AAru(0)2Aru(0)Aru(0)2Ar+1c1η1=λ1η1
即:当r足够大时, u = η 1 u=\eta_1 u=η1 是最大特征值对应的特征向量。此时, u T A u = λ 1 u^TAu=\lambda_1 uTAu=λ1。以上(9)式所表达的迭代过程就是产生 u u u的过程。
最后,谱正则项的实现算法如下:
谱范数正则(Spectral Norm Regularization)的理解_第1张图片

小结:

谱正则来自于一个朴素的直觉:局部最小值处平坦,则泛化能力强。然后,[2]从前馈网络入手,导出以矩阵相乘形式的近似网络函数,让我们能够用矩阵进行奇异值方法去分析,从而说明局部平坦与奇异值之间的关系,最后,在此基础上给出一个可行的正则项设计。
这个推导过程值得我们去学习。


参考文献:
[1] 《LARGE SCALE GAN TRAINING FOR
HIGH FIDELITY NATURAL IMAGE SYNTHESIS》(arXiv:1809.11096v1)
[2] Spectral Norm Regularization for Improving the Generalizability of Deep Learning, Yuchi Yoshida, National Institute of Informatics, 2017. 5, (arXiv: 1705.10941v1)
[3] Spectral Normalization for Generative Adversarial Networks, Takeru Miyato, Yuchi Yoshida, 2018.2(arXiv: 1802.05957v1)
[4] 苏剑林. (2018, Oct 07). 《深度学习中的Lipschitz约束:泛化与生成模型 》[Blog post]. Retrieved from https://kexue.fm/archives/6051

你可能感兴趣的:(机器学习与神经网络)