【深度学习】神经正切核(NTK)理论

神经正切核理论

​ 本文来自于《Theory of Deep Learning》,主要是对神经正切核(NTK)理论进行介绍。这里主要是补充了一些基本概念以及部分推导过程。作为软件工程出身,数学不是特别好,有些基础知识和推导步骤没办法一次补足。若有机会,后续会逐步补全缺失的部分。

一、基础知识

1. Hoeffding不等式

​ 设 X 1 , … , X n X_1,\dots,X_n X1,,Xn n n n个独立的随机变量,且 X i X_i Xi的边界为 [ a i , b i ] [a_i,b_i] [ai,bi]。令 X ˉ = 1 n ∑ i = 1 n X i \bar{X}=\frac{1}{n}\sum_{i=1}^n X_i Xˉ=n1i=1nXi,则有
P ( ∣ X ˉ − E ( X ˉ ) ∣ ≥ t ) ≤ exp ⁡ ( − 2 n 2 t 2 ∑ i = 1 n ( b i − a i ) 2 ) P(|\bar{X}-E(\bar{X})|\geq t)\leq \exp\Big(-\frac{2n^2t^2}{\sum_{i=1}^n(b_i-a_i)^2}\Big) \\ P(XˉE(Xˉ)t)exp(i=1n(biai)22n2t2)

2. Boole不等式

​ 令 A i A_i Ai表达第 i i i个随机事件,那么有
P ( ∪ i A i ) ≤ ∑ i P ( A i ) P\Big(\cup_i A_i\Big)\leq\sum_i P(A_i) \\ P(iAi)iP(Ai)
即至少一个事件发生的概率不大于单独事件发生概率之和。

3. 核函数与核回归

核函数。 X \mathcal{X} X是输入空间, H \mathcal{H} H是特征空间,若存在一个从 X \mathcal{X} X H \mathcal{H} H的映射

ϕ ( x ) : X → H \phi(\textbf{x}):\mathcal{X}\rightarrow\mathcal{H} \\ ϕ(x):XH
使得对所有的 x , z ∈ X \textbf{x},\textbf{z}\in\mathcal{X} x,zX,函数 k ( x , z ) k(\textbf{x},\textbf{z}) k(x,z)均满足
k ( x , z ) = ⟨ ϕ ( x ) , ϕ ( z ) ⟩ k(\textbf{x},\textbf{z})=\langle \phi(\textbf{x}),\phi(\textbf{z})\rangle \\ k(x,z)=ϕ(x),ϕ(z)⟩
则称 k ( x , z ) k(\textbf{x},\textbf{z}) k(x,z)是核函数, ϕ ( x ) \phi(\textbf{x}) ϕ(x)是映射函数, ⟨ ϕ ( x ) , ϕ ( z ) ⟩ \langle \phi(\textbf{x}),\phi(\textbf{z}) \rangle ϕ(x),ϕ(z)⟩表示 ϕ ( x ) \phi(\textbf{x}) ϕ(x) ϕ ( z ) \phi(\textbf{z}) ϕ(z)的内积。核函数的作用是特征映射后求内积,但是不一定需要显示进行映射。

高斯核是一种常见的核函数,定义为
k ( x , z ) = exp ⁡ ( − γ ∥ x − z ∥ 2 ) k(\textbf{x},\textbf{z})=\exp(-\gamma\parallel \textbf{x}-\textbf{z}\parallel^2) \\ k(x,z)=exp(γxz2)
其可以将特征映射至无穷维,因此
exp ⁡ ( − ∥ x − z ∥ 2 ) = exp ⁡ ( − x ⊤ x − z ⊤ z + 2 x ⊤ z ) = exp ⁡ ( − x ⊤ x ) exp ⁡ ( z ⊤ z ) exp ⁡ ( 2 x ⊤ z ) = exp ⁡ ( − x ⊤ x ) exp ⁡ ( z ⊤ z ) ( ∑ k = 0 ∞ ( 2 x ⊤ z ) k k ! ) = ∑ k = 0 ∞ [ exp ⁡ ( − x ⊤ x ) exp ⁡ ( − z ⊤ z ) 2 k k ! 2 k k ! ( x k ) ⊤ ( z k ) ] = ϕ ( x ) ⊤ ϕ ( z ) \begin{align} \exp(-\parallel \textbf{x}-\textbf{z}\parallel^2)&=\exp(-\textbf{x}^\top\textbf{x}-\textbf{z}^\top\textbf{z}+2\textbf{x}^\top\textbf{z}) \\ &=\exp(-\textbf{x}^\top\textbf{x})\exp(\textbf{z}^\top\textbf{z})\exp(2\textbf{x}^\top\textbf{z}) \\ &=\exp(-\textbf{x}^\top\textbf{x})\exp(\textbf{z}^\top\textbf{z})\Big(\sum_{k=0}^{\infty}\frac{(2\textbf{x}^\top\textbf{z})^k}{k!}\Big) \\ &=\sum_{k=0}^{\infty}\Big[ \exp(-\textbf{x}^\top\textbf{x})\exp(-\textbf{z}^\top \textbf{z})\sqrt{\frac{2^k}{k!}}\sqrt{\frac{2^k}{k!}}(\textbf{x}^k)^\top(\textbf{z}^k) \Big] \\ &=\phi(\textbf{x})^\top\phi(\textbf{z}) \end{align} \\ exp(xz2)=exp(xxzz+2xz)=exp(xx)exp(zz)exp(2xz)=exp(xx)exp(zz)(k=0k!(2xz)k)=k=0[exp(xx)exp(zz)k!2k k!2k (xk)(zk)]=ϕ(x)ϕ(z)
(上式第三等号使用了Taylor展开 exp ⁡ ( 2 x ⊤ z ) = ∑ 0 ∞ ( 2 x ⊤ z ) k k ! \exp(2\textbf{x}^\top\textbf{z})=\sum_{0}^{\infty}\frac{(2\textbf{x}^\top\textbf{z})^k}{k!} exp(2xz)=0k!(2xz)k

基于上式可以得到高斯核的映射函数为
ϕ ( x ) = exp ⁡ ( − x ⊤ x ) ( 1 , 2 1 1 ! x 1 , 2 2 2 ! x 2 , … , 2 k k ! x k , … ) \phi(\textbf{x})=\exp(-\textbf{x}^\top\textbf{x})\Big( 1,\sqrt{\frac{2^1}{1!}}\textbf{x}^1,\sqrt{\frac{2^2}{2!}}\textbf{x}^2,\dots,\sqrt{\frac{2^k}{k!}}\textbf{x}^k,\dots \Big) \\ ϕ(x)=exp(xx)(1,1!21 x1,2!22 x2,,k!2k xk,)

核回归。核回归是经典的非线性回归算法。给定训练集 ( X , y ) = { ( x i , y i ) } i = 1 n (\textbf{X},\textbf{y})=\{(\textbf{x}_i,y_i)\}_{i=1}^n (X,y)={(xi,yi)}i=1n,其中 x i \textbf{x}_i xi是输入数据, y i = f ( x i ) y_i=f(\textbf{x}_i) yi=f(xi)是对应的标量标签,核回归的目标是构建一个估计函数
f ^ ( x ) = ∑ i = 1 n ( K − 1 y ) i k ( x i , x ) \hat{f}(\textbf{x})=\sum_{i=1}^n(\textbf{K}^{-1}\textbf{y})_i k(\textbf{x}_i,\textbf{x}) \\ f^(x)=i=1n(K1y)ik(xi,x)
其中 K \textbf{K} K n × n n\times n n×n的核矩阵,该矩阵的每个分量为 K i j = k ( x i , x j ) \textbf{K}_{ij}=k(\textbf{x}_i,\textbf{x}_j) Kij=k(xi,xj) k k k是对称半正定核函数。

​ 直觉上,核回归对于任意数据点 x \textbf{x} x的估计值可以看做是训练数据 x i \textbf{x}_i xi x \textbf{x} x的相似性作为权重,然后对训练标签 y i y_i yi进行加权求和。

二、预测的演化方程

​ 设神经网络的输出表示为 f ( w , x ) ∈ R f(w,x)\in\mathbb{R} f(w,x)R,其中 w ∈ R N w\in\mathbb{R}^N wRN是网络中的所有参数, x ∈ R d x\in\mathbb{R}^d xRd是输入。给定训练数据 { ( x i , y i ) } i = 1 n ⊂ R d × R \{(x_i,y_i)\}_{i=1}^n\subset\mathbb{R}^d\times\mathbb{R} {(xi,yi)}i=1nRd×R,通过最小化训练数据上的均方误差来训练神经网络:
l ( w ) = 1 2 ∑ i = 1 n ( f ( w , x i ) − y i ) 2 (1) \mathcal{l}(w)=\frac{1}{2}\sum_{i=1}^n(f(w,x_i)-y_i)^2 \tag{1} \\ l(w)=21i=1n(f(w,xi)yi)2(1)
这里主要研究梯度流(gradient flow),也就是极小学习率的梯度下降。在上面的例子中,预测的动力学可以描述为常微分方程:
d w ( t ) d t = − ∇ l ( w ( t ) ) (2) \frac{d w(t)}{dt}=-\nabla\mathcal{l}(w(t)) \tag{2} \\ dtdw(t)=l(w(t))(2)

引理1

u ( t ) = ( f ( w ( t ) , x i ) ) i ∈ [ n ] ∈ R n u(t)=(f(w(t),x_i))_{i\in[n]}\in\mathbb{R}^n u(t)=(f(w(t),xi))i[n]Rn表示神经网络在时刻 t t t的所有输出 x i ′ x_i' xi y = ( y i ) i ∈ [ n ] y=(y_i)_{i\in[n]} y=(yi)i[n]是标签。 u ( t ) u(t) u(t)的演化遵循
d u ( t ) d t = − H ( t ) ⋅ ( u ( t ) − y ) (3) \frac{du(t)}{dt}=-H(t)\cdot(u(t)-y) \tag{3} \\ dtdu(t)=H(t)(u(t)y)(3)
其中, H ( t ) H(t) H(t) n × n n\times n n×n的半正定矩阵,其第 ( i , j ) (i,j) (i,j)个元素是 ⟨ ∂ f ( w ( t ) , x i ) ∂ w , ∂ f ( w ( t ) , x j ) ∂ w ⟩ \langle\frac{\partial f(w(t),x_i)}{\partial w},\frac{\partial f(w(t),x_j)}{\partial w}\rangle wf(w(t),xi),wf(w(t),xj)

证明。参数 w w w的演化是基于下面的微分方程
d w ( t ) d t = − ∇ l ( w ( t ) ) = − ∑ i = 1 n ( f ( w ( t ) , x i ) − y i ) ∂ f ( w ( t ) , x i ) ∂ w (4) \frac{dw(t)}{dt}=-\nabla\mathcal{l}(w(t))=-\sum_{i=1}^n(f(w(t),x_i)-y_i)\frac{\partial f(w(t),x_i)}{\partial w} \tag{4} \\ dtdw(t)=l(w(t))=i=1n(f(w(t),xi)yi)wf(w(t),xi)(4)
其中 t ≥ 0 t\geq 0 t0是连续的时间坐标。基于等式(4),网络输出 f ( w ( t ) , x i ) f(w(t),x_i) f(w(t),xi)的演化可以写作
d f ( w ( t ) , x i ) d t = ⟨ ∂ f ( w ( t ) , x i ) ∂ w ( t ) , ∂ w ( t ) ∂ t ⟩ = ⟨ ∂ f ( w ( t ) , x i ) ∂ w ( t ) , − ∑ j = 1 n ( f ( w ( t ) , x j ) − y j ) ∂ f ( w ( t ) , x j ) ∂ w ⟩ = − ∑ j = 1 n ( f ( w ( t ) , x j ) , y j ) ⟨ ∂ f ( w ( t ) , x i ) ∂ w , ∂ f ( w ( t ) , x j ) ∂ w ⟩ (5) \begin{align} \frac{df(w(t),x_i)}{dt}&=\Big\langle\frac{\partial f(w(t),x_i)}{\partial w(t)},\frac{\partial w(t)}{\partial t}\Big\rangle \\ &=\Big\langle \frac{\partial f(w(t),x_i)}{\partial w(t)}, -\sum_{j=1}^n(f(w(t),x_j)-y_j)\frac{\partial f(w(t),x_j)}{\partial w} \Big\rangle \\ &=-\sum_{j=1}^n(f(w(t),x_j),y_j)\Big\langle \frac{\partial f(w(t),x_i)}{\partial w}, \frac{\partial f(w(t),x_j)}{\partial w}\Big\rangle \\ \end{align} \tag{5} \\ dtdf(w(t),xi)=w(t)f(w(t),xi),tw(t)=w(t)f(w(t),xi),j=1n(f(w(t),xj)yj)wf(w(t),xj)=j=1n(f(w(t),xj),yj)wf(w(t),xi),wf(w(t),xj)(5)
因为 u ( t ) = ( f ( w ( t ) , x i ) ) i ∈ [ n ] ∈ R n u(t)=(f(w(t),x_i))_{i\in[n]}\in\mathbb{R}^n u(t)=(f(w(t),xi))i[n]Rn是神经网络 t t t时刻在所有 x i x_i xi上的输出, y = ( y i ) i ∈ [ n ] y=(y_i)_{i\in[n]} y=(yi)i[n]是标签。等式(5)可以紧凑的写作
d u ( t ) d t = − H ( t ) ⋅ ( u ( t ) − y ) (6) \frac{du(t)}{dt}=-H(t)\cdot(u(t)-y) \tag{6} \\ dtdu(t)=H(t)(u(t)y)(6)
其中 H ( t ) ∈ R n × n H(t)\in\mathbb{R}^{n\times n} H(t)Rn×n是定义为 [ H ( t ) ] i , j = ⟨ ∂ f ( w ( t ) , x i ) ∂ w , ∂ f ( w ( t ) , x j ) ∂ w ⟩ ( ∀ i , j ∈ [ n ] ) [H(t)]_{i,j}=\langle\frac{\partial f(w(t),x_i)}{\partial w},\frac{\partial f(w(t),x_j)}{\partial w} \rangle(\forall i,j\in[n]) [H(t)]i,j=wf(w(t),xi),wf(w(t),xj)(i,j[n])

​ 上面引理涉及到矩阵 H ( t ) H(t) H(t)。下面将会定义一个无限宽的神经网络,并固定训练数据。在这种限制下,训练过程中的矩阵 H ( t ) H(t) H(t)为常数,即 H ( t ) H(t) H(t)的等于 H ( 0 ) H(0) H(0)。此外,对于随机初始化参数,当网络宽度为无限时,随机矩阵 H ( 0 ) H(0) H(0)概率收敛至某个确定的核矩阵 H ∗ H^* H,该矩阵就是通过训练数据估计出的神经正切核(Neural Tangent Kernel, NTK) k ( ⋅ , ⋅ ) k(\cdot,\cdot) k(,)。若对于所有 t t t均有 H ( t ) = H ∗ H(t)=H^* H(t)=H,那么等式(3)就变成
d u ( t ) d t = − H ∗ ⋅ ( u ( t ) − y ) (7) \frac{d u(t)}{dt}=-H^*\cdot(u(t)-y) \tag{7} \\ dtdu(t)=H(u(t)y)(7)
可以发现上述公式的动力学与梯度流下的核回归一致,那么当 t → ∞ t\rightarrow\infty t时最终的预测函数为
f ∗ ( x ) = ( k ( x , x 1 ) , … , k ( x , x n ) ) ⋅ ( H ∗ ) − 1 y (8) f^*(x)=(k(x,x_1),\dots,k(x,x_n))\cdot(H^*)^{-1}y\tag{8} \\ f(x)=(k(x,x1),,k(x,xn))(H)1y(8)

三、无限宽网络与神经正切核(NTK)

​ 下面是一个简单的两层神经网络
f ( a , W , x ) = 1 m ∑ r = 1 m a r σ ( w r T x ) (9) f(a,W,x)=\frac{1}{\sqrt{m}}\sum_{r=1}^m a_r\sigma(w_r^Tx) \tag{9} \\ f(a,W,x)=m 1r=1marσ(wrTx)(9)
其中 m m m是网络的宽度, σ ( ⋅ ) \sigma(\cdot) σ()是激活函数。这里假设对于所有的 z ∈ R z\in\mathbb{R} zR ∣ σ ′ ( z ) ∣ |\sigma'(z)| σ(z) ∣ σ ′ ′ ( z ) ∣ |\sigma''(z)| σ′′(z)的上界均为1,例如 σ ( z ) = log ⁡ ( 1 + exp ⁡ ( z ) ) \sigma(z)=\log(1+\exp(z)) σ(z)=log(1+exp(z))就满足这个假设。假设所有的输入 x x x的Euclidean范数均为1,即 ∥ x ∥ 2 = 1 \parallel x\parallel_2=1 x2=1。缩放因子 1 m \frac{1}{\sqrt{m}} m 1在证明 H ( t ) H(t) H(t)接近于固定核 H ∗ H^* H上扮演者重要的角色。使用范式 ∥ ⋅ ∥ 2 \parallel\cdot\parallel_2 2来衡量两个矩阵 A A A B B B的接近程度。

先计算 H ( 0 ) H(0) H(0),并展示 m → ∞ m\rightarrow\infty m H ( 0 ) H(0) H(0)收敛至固定矩阵 H ∗ H^* H 注意, ∂ f ( a , W , x i ) ∂ w r = 1 m a r x i σ ′ ( w r ⊤ x i ) \frac{\partial f(a,W,x_i)}{\partial w_r}=\frac{1}{\sqrt{m}}a_r x_i\sigma'(w_r^\top x_i) wrf(a,W,xi)=m 1arxiσ(wrxi)。因此, H ( 0 ) H(0) H(0)中的每个元素为
[ H ( 0 ) ] i j = ∑ r = 1 m ⟨ ∂ f ( a , W ( 0 ) , x i ) ∂ w r ( 0 ) , ∂ f ( a , W ( 0 ) , x j ) ∂ w r ( 0 ) ⟩ = ∑ r = 1 m ⟨ 1 m a r x i σ ′ ( w r ( 0 ) ⊤ x i ) , 1 m a r x j σ ′ ( w r ( 0 ) ⊤ x i ) ⟩ = x i ⊤ x j ⋅ ∑ r = 1 m σ ′ ( w r ( 0 ) ⊤ x i ) σ ′ ( w r ( 0 ) ⊤ x j ) m (8) \begin{align} [H(0)]_{ij}&=\sum_{r=1}^m\Big\langle \frac{\partial f(a,W(0),x_i)}{\partial w_r(0)},\frac{\partial f(a,W(0),x_j)}{\partial w_r(0)} \Big\rangle \\ &=\sum_{r=1}^m\Big\langle\frac{1}{\sqrt{m}}a_rx_i\sigma'(w_r(0)^\top x_i),\frac{1}{\sqrt{m}}a_rx_j\sigma'(w_r(0)^\top x_i)\Big\rangle \\ &=x_i^\top x_j\cdot\frac{\sum_{r=1}^m\sigma'(w_r(0)^\top x_i)\sigma'(w_r(0)^\top x_j)}{m} \\ \end{align} \tag{8} \\ [H(0)]ij=r=1mwr(0)f(a,W(0),xi),wr(0)f(a,W(0),xj)=r=1mm 1arxiσ(wr(0)xi),m 1arxjσ(wr(0)xi)=xixjmr=1mσ(wr(0)xi)σ(wr(0)xj)(8)
最后一步,由于 a r ∼ Unif [ { − 1 , 1 } ] a_r\sim\text{Unif}[\{-1,1\}] arUnif[{1,1}],因此对于所有的 r = 1 , … , m r=1,\dots,m r=1,,m,有 a r 2 = 1 a_r^2=1 ar2=1。对于所有的 w r ( 0 ) w_r(0) wr(0)都是从标准高斯分布中独立同分布采样出来的。因此,可以将 [ H ( 0 ) ] i j [H(0)]_{ij} [H(0)]ij看做是m个独立同分布随机变量的平均值。若 m m m很大,那么基于大数定律,这个平均值接近于随机变量的期望。在 x i x_i xi x j x_j xj上由NTK评估的期望为:
H i j ∗ ≜ x i ⊤ x j ⋅ E w ∼ N ( 0 , I ) [ σ ′ ( w ⊤ x i ) σ ′ ( w T x j ) ] (9) H_{ij}^*\triangleq x_i^\top x_j\cdot\mathbb{E}_{w\sim N(0,I)}[\sigma'(w^\top x_i)\sigma'(w^T x_j)] \tag{9} \\ HijxixjEwN(0,I)[σ(wxi)σ(wTxj)](9)
基于Hoeffding不等式和Boole不等式,可以容易得知 H ( 0 ) H(0) H(0)逼近于 H ∗ H^* H

引理2

​ 对于某个 ϵ > 0 \epsilon>0 ϵ>0。若 m = Ω ( n 4 log ⁡ ( n / δ ) ϵ 2 ) m=\Omega(\frac{n^4\log(n/\delta)}{\epsilon^2}) m=Ω(ϵ2n4log(n/δ)),那么 w 1 ( 0 ) , … , w m ( 0 ) w_1(0),\dots,w_m(0) w1(0),,wm(0)至少以概率 1 − δ 1-\delta 1δ满足
∥ H ( 0 ) − H ∗ ∥ 2 ≤ ϵ \parallel H(0)-H^*\parallel_2\leq\epsilon \\ H(0)H2ϵ
证明。对于分量 ( i , j ) (i,j) (i,j),由于 ∣ σ ′ ( z ) ∣ ≤ 1 |\sigma'(z)|\leq 1 σ(z)1 ∥ x ∥ = 1 \parallel x\parallel=1 x∥=1,那么有
∣ x i ⊤ x j σ ′ ( w t ( 0 ) ⊤ x i ) σ ′ ( w r ( 0 ) ⊤ x j ) ∣ ≤ 1 |x_i^\top x_j\sigma'(w_t(0)^\top x_i)\sigma'(w_r(0)^\top x_j)|\leq 1 \\ xixjσ(wt(0)xi)σ(wr(0)xj)1
因此, [ H ( 0 ) ] i j [H(0)]_{ij} [H(0)]ij的边界为 [ 0 , 1 ] [0,1] [0,1]。应用Hoeffding不等式,有
P ( ∣ [ H ( 0 ) ] i j − H i j ∗ ∣ ≥ ϵ n 2 ) ≤ exp ⁡ ( − 2 m 2 ( ϵ n 2 ) 2 ∑ i = 1 m ( 1 − 0 ) 2 ) = exp ⁡ ( − 2 m ϵ 2 n 4 ) ≤ exp ⁡ ( − 2 ϵ 2 n 4 n 4 log ⁡ ( n / δ ) ϵ 2 ) = exp ⁡ ( − 2 log ⁡ ( n / δ ) ) = δ 2 n 2 ≤ δ n 2 \begin{align} P\Big(|[H(0)]_{ij}-H_{ij}^*|\geq \frac{\epsilon}{n^2}\Big)&\leq \exp(-\frac{2m^2(\frac{\epsilon}{n^2})^2}{\sum_{i=1}^m(1-0)^2}) \\ &=\exp(-\frac{2m\epsilon^2}{n^4}) \\ &\leq\exp(-\frac{2\epsilon^2}{n^4}\frac{n^4\log(n/\delta)}{\epsilon^2}) \\ &=\exp(-2\log(n/\delta)) \\ &=\frac{\delta^2}{n^2}\leq\frac{\delta}{n^2} \\ \end{align} \\ P([H(0)]ijHijn2ϵ)exp(i=1m(10)22m2(n2ϵ)2)=exp(n42mϵ2)exp(n42ϵ2ϵ2n4log(n/δ))=exp(2log(n/δ))=n2δ2n2δ
(注: n n n是训练样本数, m m m是网络宽度)

那么有
P ( ∣ [ H ( 0 ) ] i j − H i j ∗ ∣ ≤ ϵ n 2 ) = 1 − P ( ∣ [ H ( 0 ) ] i j − H i j ∗ ∣ ≥ ϵ n 2 ) ≥ 1 − δ n 2 \begin{align} P\Big(|[H(0)]_{ij}-H_{ij}^*|\leq \frac{\epsilon}{n^2}\Big)&=1-P\Big(|[H(0)]_{ij}-H_{ij}^*|\geq \frac{\epsilon}{n^2}\Big)\geq 1-\frac{\delta}{n^2} \\ \end{align} \\ P([H(0)]ijHijn2ϵ)=1P([H(0)]ijHijn2ϵ)1n2δ
将上面的结论应用在所有 ( i , j ) ∈ [ n ] × [ n ] (i,j)\in[n]\times[n] (i,j)[n]×[n],并使用Boole不等式
∥ H ( 0 ) − H ∗ ∥ 2 ≤ ∥ H ( 0 ) − H ∗ ∥ F ≤ ∑ i j ∣ [ H ( 0 ) ] i j − H i j ∗ ∣ ≤ n 2 ⋅ ϵ n 2 = ϵ \parallel H(0)-H^* \parallel_2\leq\parallel H(0)-H^* \parallel_F\leq\sum_{ij}|[H(0)]_{ij}-H_{ij}^*|\leq n^2\cdot\frac{\epsilon}{n^2}=\epsilon \\ H(0)H2≤∥H(0)HFij[H(0)]ijHijn2n2ϵ=ϵ

接下来证明在训练过程中, H ( t ) H(t) H(t)逼近 H ( 0 ) H(0) H(0)

引理3

​ 假设对于所有的 i = 1 , … , n i=1,\dots,n i=1,,n都有 y i = O ( 1 ) y_i=O(1) yi=O(1)。给定 t > 0 t>0 t>0,对任意的 0 ≤ τ ≤ t 0\leq\tau\leq t 0τt,所有的 i = 1 , … , n i=1,\dots,n i=1,,n都有 u i ( τ ) = O ( 1 ) u_i(\tau)=O(1) ui(τ)=O(1)。若 m = Ω ( n 6 t 2 ϵ 2 ) m=\Omega(\frac{n^6t^2}{\epsilon^2}) m=Ω(ϵ2n6t2),有
∥ H ( t ) − H ( 0 ) ∥ 2 ≤ ϵ \parallel H(t)-H(0) \parallel_2\leq\epsilon \\ H(t)H(0)2ϵ
(直观解释:若所有样本的标签值均不大于1,且0到 t t t时刻中的任意时刻 τ \tau τ,模型的预测值也不大于1。那么当网络宽度 m m m大于 n 6 t 2 ϵ 2 \frac{n^6t^2}{\epsilon^2} ϵ2n6t2时, t t t时刻的NTK核逼近于初始的NTK核)。

证明。第一个关键思想是:当 m m m很大时,每个权重向量变化量很小。下面是单个权重向量的变化
∥ w r ( t ) − w r ( 0 ) ∥ 2 = ∥ ∫ 0 t d w r ( τ ) d τ d τ ∥ 2 = ∥ ∫ 0 t ∑ i = 1 n ( u i ( τ ) − y i ) ∂ u i ( τ ) ∂ w d τ ∥ 2 = ∥ ∫ 0 t ∑ i = 1 n ( u i ( τ ) − y i ) 1 m a r x i σ ′ ( w r ( τ ) ⊤ x i ) d τ ∥ 2 ≤ 1 m ∫ ∥ ∑ i = 1 n ( u i ( τ ) − y i ) a r x i σ ′ ( w r ( τ ) ⊤ x i ) ∥ 2 d τ ≤ 1 m ∑ i = 1 n ∫ 0 t ∥ u i ( τ ) − y i a r x i σ ′ ( w r ( τ ) ⊤ x i ) ∥ 2 d τ ≤ 1 m ∑ i = 1 n ∫ 0 t O ( 1 ) d τ = O ( t n m ) \begin{align} \parallel w_r(t)-w_r(0) \parallel_2&=\Big\| \int_{0}^t\frac{dw_r(\tau)}{d\tau}d\tau \Big\|_2 \\ &=\Big\|\int_{0}^t \sum_{i=1}^n(u_i(\tau)-y_i)\frac{\partial u_i(\tau)}{\partial w} d\tau \Big\|_2 \\ &=\Big\| \int_{0}^t\sum_{i=1}^n(u_i(\tau)-y_i)\frac{1}{\sqrt{m}}a_rx_i\sigma'(w_r(\tau)^\top x_i) d\tau \Big\|_2 \\ &\leq\frac{1}{\sqrt{m}}\int\Big\|\sum_{i=1}^n(u_i(\tau)-y_i)a_rx_i\sigma'(w_r(\tau)^\top x_i) \Big\|_2d\tau \\ &\leq\frac{1}{\sqrt{m}}\sum_{i=1}^n\int_{0}^t\| u_i(\tau)-y_ia_rx_i\sigma'(w_r(\tau)^\top x_i) \|_2 d\tau \\ &\leq\frac{1}{\sqrt{m}}\sum_{i=1}^n\int_{0}^t O(1) d\tau=O(\frac{tn}{\sqrt{m}}) \\ \end{align} \\ wr(t)wr(0)2= 0tdτdwr(τ)dτ 2= 0ti=1n(ui(τ)yi)wui(τ)dτ 2= 0ti=1n(ui(τ)yi)m 1arxiσ(wr(τ)xi)dτ 2m 1 i=1n(ui(τ)yi)arxiσ(wr(τ)xi) 2dτm 1i=1n0tui(τ)yiarxiσ(wr(τ)xi)2dτm 1i=1n0tO(1)dτ=O(m tn)
上面的结果表明:给定任意 t t t,只要 m m m足够大,则 w r ( t ) w_r(t) wr(t)就接近于 w r ( 0 ) w_r(0) wr(0)。下面将证明这意味着核矩阵 H ( t ) H(t) H(t)接近于 H ( 0 ) H(0) H(0)。这里证明单个分量的差距
[ H ( t ) ] i j − [ H ( 0 ) ] i j = ∣ 1 m ∑ r = 1 m ( σ ′ ( w r ( t ) ⊤ x i ) σ ′ ( w r ( t ) ⊤ x j ) − σ ′ ( w r ( 0 ) ⊤ x i ) σ ′ ( w r ( 0 ) ⊤ x j ) ) ∣ ≤ 1 m ∑ r = 1 m ∣ σ ′ ( w r ( t ) ⊤ x i ) ( σ ′ ( w r ( t ) ⊤ x j ) − σ ′ ( w r ( 0 ) ⊤ x j ) ) ∣ + 1 m ∑ r = 1 m ∣ σ ′ ( w r ( 0 ) ⊤ x j ) ( σ ′ ( w r ( t ) ⊤ x j ) − σ ′ ( w r ( 0 ) ⊤ x i ) ) ∣ ≤ 1 m ∑ r = 1 m ∣ max ⁡ r σ ′ ( w r ( t ) ⊤ x i ) ∥ x i ∥ 2 ∥ w r ( t ) − w r ( 0 ) ∥ 2 ∣ + 1 m ∑ r = 1 m ∣ max ⁡ r σ ′ ( w r ( t ) ⊤ x i ) ∥ x i ∥ 2 ∥ w r ( t ) − w r ( 0 ) ∥ 2 ∣ = 1 m ∑ r = 1 m O ( t n m ) \begin{align} &[H(t)]_{ij}-[H(0)]_{ij} \\ =&\Big| \frac{1}{m}\sum_{r=1}^m\Big( \sigma'(w_r(t)^\top x_i)\sigma'(w_r(t)^\top x_j)- \sigma'(w_r(0)^\top x_i)\sigma'(w_r(0)^\top x_j)\Big) \Big| \\ \leq&\frac{1}{m}\sum_{r=1}^m\Big|\sigma'(w_r(t)^\top x_i)(\sigma'(w_r(t)^\top x_j)-\sigma'(w_r(0)^\top x_j)) \Big| \\ &+\frac{1}{m}\sum_{r=1}^m\Big|\sigma'(w_r(0)^\top x_j)(\sigma'(w_r(t)^\top x_j)-\sigma'(w_r(0)^\top x_i)) \Big| \\ \leq&\frac{1}{m}\sum_{r=1}^m\Big|\max_r \sigma'(w_r(t)^\top x_i)\|x_i\|_2\| w_r(t)-w_r(0) \|_2 \Big| \\ &+\frac{1}{m}\sum_{r=1}^m\Big|\max_r \sigma'(w_r(t)^\top x_i)\|x_i\|_2\| w_r(t)-w_r(0) \|_2 \Big| \\ =&\frac{1}{m}\sum_{r=1}^m O(\frac{tn}{\sqrt{m}}) \\ \end{align} \\ ==[H(t)]ij[H(0)]ij m1r=1m(σ(wr(t)xi)σ(wr(t)xj)σ(wr(0)xi)σ(wr(0)xj)) m1r=1m σ(wr(t)xi)(σ(wr(t)xj)σ(wr(0)xj)) +m1r=1m σ(wr(0)xj)(σ(wr(t)xj)σ(wr(0)xi)) m1r=1m rmaxσ(wr(t)xi)xi2wr(t)wr(0)2 +m1r=1m rmaxσ(wr(t)xi)xi2wr(t)wr(0)2 m1r=1mO(m tn)
因此,有
∥ H ( t ) − H ( 0 ) ∥ 2 ≤ ∑ i , j ∣ [ H ( t ) ] i j − [ H ( 0 ) ] i j ∣ = O ( t n 3 m ) \| H(t)-H(0)\|_2\leq\sum_{i,j}\Big|[H(t)]_{ij}-[H(0)]_{ij} \Big|=O\Big(\frac{tn^3}{\sqrt{m}}\Big) \\ H(t)H(0)2i,j [H(t)]ij[H(0)]ij =O(m tn3)

四、用NTK解释无限宽网络的优化和泛化

​ 基于上面的结论有
d u ( t ) d t ≈ − H ∗ ⋅ ( u ( t ) − y ) (10) \frac{du(t)}{d_t}\approx -H^*\cdot(u(t)-y) \tag{10}\\ dtdu(t)H(u(t)y)(10)
其中 H ∗ H^* H是NTK矩阵。接下来基于该近似分析无限宽神经网络的优化和泛化。

1. 优化

U ( t ) U(t) U(t)的动力学遵循
d u ( t ) d t = − H ∗ ⋅ ( u ( t ) − y ) (11) \frac{du(t)}{d_t}= -H^*\cdot(u(t)-y) \tag{11}\\ dtdu(t)=H(u(t)y)(11)
本质上是线性动力系统。对 H ∗ H^* H进行特征值分解的
H ∗ = ∑ i = 1 n λ i v i v i ⊤ (12) H^*=\sum_{i=1}^n\lambda_i v_i v_i^\top \tag{12}\\ H=i=1nλivivi(12)
其中 λ 1 ≥ ⋯ ≥ λ n ≥ 0 \lambda_1\geq\dots\geq\lambda_n\geq 0 λ1λn0是特征值, v 1 , … , v n v_1,\dots,v_n v1,,vn是特征向量。基于该分解可以分别研究 u ( t ) u(t) u(t)在每个特征向量上的动力学。对等式(12)两边同时乘以 v i v_i vi得,得到 u ( t ) u(t) u(t)在特征向量 v i v_i vi上的动力学
d v i ⊤ u ( t ) d t = − v i ⊤ H ∗ ⋅ ( u ( t ) − y ) = − v i ⊤ ∑ i = 1 n λ i v i v i ⊤ ⋅ ( u ( t ) − y ) = − λ i ( v i ⊤ ( u ( t ) − y ) ) (13) \begin{align} \frac{dv_i^\top u(t)}{dt}&=-v_i^\top H^*\cdot(u(t)-y) \\ &=-v_i^\top\sum_{i=1}^n\lambda_i v_i v_i^\top\cdot(u(t)-y) \\ &=-\lambda_i(v_i^\top(u(t)-y)) \\ \end{align} \tag{13}\\ dtdviu(t)=viH(u(t)y)=vii=1nλivivi(u(t)y)=λi(vi(u(t)y))(13)
可以看到 v i ⊤ u ( t ) v_i^\top u(t) viu(t)的动力学仅依赖于其本身和 λ i \lambda_i λi,这其实是一个常微分方程。该常微分方程的一个解析解为
v i ⊤ ( u ( t ) − y ) = exp ⁡ ( − λ i t ) ( v i ⊤ ( u ( 0 ) − y ) ) (14) v_i^\top(u(t)-y)=\exp(-\lambda_i t)\Big(v_i^\top(u(0)-y) \Big) \tag{14}\\ vi(u(t)y)=exp(λit)(vi(u(0)y))(14)

现在使用上面的等式来解释为什么可以找到0训练误差解。假设对于所有的 i = 1 , … , n i=1,\dots,n i=1,,n均有 λ i > 0 \lambda_i>0 λi>0,即核矩阵的所有特征值均严格为正。

( u ( t ) − y ) (u(t)-y) (u(t)y)表示 t t t时刻预测值和训练标签之间的差值。若当 t → ∞ t\rightarrow\infty t,有 u ( t ) − y → 0 u(t)-y\rightarrow 0 u(t)y0时,表示存在一个训练误差为0的算法。等式(14)表示该差值的分量,由于项 exp ⁡ ( − λ i t ) \exp(-\lambda_i t) exp(λit),所以 v i ⊤ ( u ( t ) − y ) v_i^\top(u(t)-y) vi(u(t)y)会以指数级的速度收敛至0。此外,由于 { v 1 , … , v n } \{v_1,\dots,v_n\} {v1,,vn} R n \mathbb{R}^n Rn上的一个正交基,因此 ( u ( t ) − y ) = ∑ i = 1 n v i ⊤ ( u ( t ) − y ) (u(t)-y)=\sum_{i=1}^nv_i^\top(u(t)-y) (u(t)y)=i=1nvi(u(t)y)。因此,当每个 v i ⊤ ( u i ( t ) − y ) → 0 v_i^\top(u_i(t)-y)\rightarrow 0 vi(ui(t)y)0,可以得到 ( u ( t ) − y ) → 0 (u(t)-y)\rightarrow 0 (u(t)y)0

​ 等式(14)本质上给出了关于收敛相关的信息,即每个分量 v i ⊤ ( u ( t ) − y ) v_i^\top(u(t)-y) vi(u(t)y)以不同的速率收敛至0。较大的 λ i \lambda_i λi对应的分量收敛到0的速度快于较小的 λ i \lambda_i λi。若期望在给定标签下能够更快的收敛,那么 y y y投影至顶部的特征应该更大。因此,可以通过下面直观的来定性比较收敛速度

  • 若标签集合 y y y对齐至顶部特征,即 ( v i ⊤ y ) (v_i^\top y) (viy)对应较大的特征值,那么梯度下降收敛较快;
  • 若标签集合 y y y投影至特征向量 { ( v i ⊤ y ) } i = 1 n \{(v_i^\top y)\}_{i=1}^n {(viy)}i=1n是均匀分布,那么梯度下降的收敛速度就较慢;

2. 泛化

​ 等式(10)中的近似意味着无限宽神经网络最终预测的函数近似于等式(8)的核预测函数。因此,可以使用核的泛化理论来分析无限宽神经网络的泛化行为。等式(8)中定义的核预测函数,使用Rademacher复杂度边界来推断下面1-Lipschitz损失函数的泛化边界
2 y ⊤ ( H ∗ ) − 1 y ⋅ t r ( H ∗ ) n (15) \frac{\sqrt{2y^\top(H^*)^{-1}y\cdot tr(H^*)}}{n} \tag{15}\\ n2y(H)1ytr(H) (15)
这是一个依赖于数据的复杂度度量的泛化误差上界。

五、多层全连接神经网络的NTK形式

​ 先来定义全连接神经网络。令 x ∈ R d x\in\mathbb{R}^d xRd表示输入,为了方便令 g ( 0 ) ( x ) = x g^{(0)}(x)=x g(0)(x)=x d 0 = d d_0=d d0=d。那么 L L L层全连接神经网络表示为
f ( h ) ( x ) = W ( h ) g ( h − 1 ) ( x ) ∈ R d h , g ( h ) ( x ) = c σ d h σ ( f ( h ) ( x ) ) ∈ R d h (16) f^{(h)}(x)=W^{(h)}g^{(h-1)}(x)\in\mathbb{R}^{d_h},g^{(h)}(x)=\sqrt{\frac{c_{\sigma}}{d_h}}\sigma\Big(f^{(h)}(x)\Big)\in\mathbb{R}^{d_h} \tag{16}\\ f(h)(x)=W(h)g(h1)(x)Rdh,g(h)(x)=dhcσ σ(f(h)(x))Rdh(16)
其中 h = 1 , 2 , … , L h=1,2,\dots,L h=1,2,,L W ( h ) ∈ R d h × d h − 1 W^{(h)}\in\mathbb{R}^{d_h\times d_{h-1}} W(h)Rdh×dh1表示第 h h h层的权重矩阵, σ : R → R \sigma:\mathbb{R}\rightarrow\mathbb{R} σ:RR是激活函数, c σ = ( E z ∼ N ( 0 , 1 ) [ σ z 2 ] ) − 1 c_{\sigma}=\Big(E_{z\sim\mathcal{N}(0,1)}[\sigma z^2]\Big)^{-1} cσ=(EzN(0,1)[σz2])1。神经网络的最后一层来自于
f ( w , x ) = f ( L + 1 ) ( x ) = W ( L + 1 ) ⋅ g ( L ) ( x ) = W ( L + 1 ) ⋅ c σ d L σ W ( L ) ⋅ c σ d L − 1 σ W ( L − 1 ) ⋯ ⋅ c σ d 1 σ W ( 1 ) x (17) \begin{align} f(w,x)&=f^{(L+1)}(x)=W^{(L+1)}\cdot g^{(L)}(x) \\ &=W^{(L+1)}\cdot\sqrt{\frac{c_{\sigma}}{d_L}}\sigma W^{(L)}\cdot\sqrt{\frac{c_{\sigma}}{d_{L-1}}}\sigma W^{(L-1)}\dots \cdot\sqrt{\frac{c_{\sigma}}{d_1}}\sigma W^{(1)}x \end{align} \tag{17}\\ f(w,x)=f(L+1)(x)=W(L+1)g(L)(x)=W(L+1)dLcσ σW(L)dL1cσ σW(L1)d1cσ σW(1)x(17)
其中 W ( L + 1 ) ∈ R 1 × d L W^{(L+1)}\in\mathbb{R}^{1\times d_L} W(L+1)R1×dL表示最后一层的权重, w = ( W ( 1 ) , … , W ( L + 1 ) ) w=\Big(W^{(1)},\dots,W^{(L+1)}\Big) w=(W(1),,W(L+1))表示神经网络的所有权重。

​ 使用标准正态分布来初始化权重并考虑hidden宽度的极限为: d 1 , d 2 , … , d L → ∞ d_1,d_2,\dots,d_L\rightarrow\infty d1,d2,,dL。缩放因子 c σ / d h \sqrt{c_{\sigma}/d_h} cσ/dh 用于确保 g ( h ) ( x ) g^{(h)}(x) g(h)(x)近似于初始化。对于ReLU集合函数,有
E [ ∥ g ( h ) ( x ) ∥ 2 2 ] = ∥ x ∥ 2 2 ( ∀ h ∈ [ L ] ) (18) E\Big[\Big\| g^{(h)}(x) \Big\|_2^2\Big]=\|x\|_2^2(\forall h\in[L]) \tag{18} \\ E[ g(h)(x) 22]=x22(h[L])(18)
正如引理1中需要计算 ⟨ ∂ f ( w ( t ) , x ) ∂ w , ∂ f ( w ( t ) , x ′ ) ∂ w ⟩ \langle\frac{\partial f(w(t),x)}{\partial w},\frac{\partial f(w(t),x')}{\partial w}\rangle wf(w(t),x),wf(w(t),x)在无限宽下收敛至随机初始化。可以将关于特定权重矩阵 W ( h ) W^{(h)} W(h)的偏导数写作
∂ f ( w , x ) ∂ W ( h ) = b ( h ) ( x ) ⋅ ( g ( h − 1 ) ( x ) ) ⊤ , h = 1 , 2 , … , L + 1 (19) \frac{\partial f(w,x)}{\partial W^{(h)}}=b^{(h)}(x)\cdot\Big(g^{(h-1)}(x)\Big)^\top,\quad h=1,2,\dots,L+1 \tag{19} \\ W(h)f(w,x)=b(h)(x)(g(h1)(x)),h=1,2,,L+1(19)
其中
b ( h ) ( x ) = { 1 ∈ R , h = L + 1 c σ d h D ( h ) ( x ) ( W ( h + 1 ) ) ⊤ b ( h + 1 ) ( x ) ∈ R d h , h = 1 , … , L (20) b^{(h)}(x)=\begin{cases} 1\in\mathbb{R},& h=L+1 \\ \sqrt{\frac{c_\sigma}{d_h}}D^{(h)}(x)\Big(W^{(h+1)} \Big)^\top b^{(h+1)}(x)\in\mathbb{R}^{d_h},& h=1,\dots,L \end{cases} \tag{20} \\ b(h)(x)= 1R,dhcσ D(h)(x)(W(h+1))b(h+1)(x)Rdh,h=L+1h=1,,L(20)

KaTeX parse error: Expected 'EOF', got '&' at position 93: …d_h\times d_h},&̲h=1,\dots,L \ta…

对于两个任意的输入 x x x x ′ x' x,任意的 h ∈ [ L + 1 ] h\in[L+1] h[L+1],可以计算
⟨ ∂ f ( w , x ) ∂ W ( h ) , ∂ f ( w , x ′ ) ∂ W ( h ) ⟩ = ⟨ b ( h ) ( x ) ⋅ ( g ( h − 1 ) ( x ) ) ⊤ , b ( h ) ( x ′ ) ⋅ ( g ( h − 1 ) ( x ′ ) ) ⊤ ⟩ = ⟨ g ( h − 1 ) ( x ) , g ( h − 1 ) ( x ′ ) ⟩ ⋅ ⟨ b ( h ) ( x ) , b ( h ) ( x ′ ) ⟩ (22) \begin{align} &\Big\langle\frac{\partial f(w,x)}{\partial W^{(h)}},\frac{\partial f(w,x')}{\partial W^{(h)}}\Big\rangle \\ =&\Big\langle b^{(h)}(x)\cdot\Big(g^{(h-1)}(x)\Big)^\top, b^{(h)}(x')\cdot\Big(g^{(h-1)}(x')\Big)^\top\Big\rangle \\ =&\langle g^{(h-1)}(x),g^{(h-1)}(x') \rangle\cdot\langle b^{(h)}(x),b^{(h)}(x') \rangle \\ \end{align} \tag{22}\\ ==W(h)f(w,x),W(h)f(w,x)b(h)(x)(g(h1)(x)),b(h)(x)(g(h1)(x))g(h1)(x),g(h1)(x)⟩b(h)(x),b(h)(x)⟩(22)
第一项 ⟨ g ( h − 1 ) ( x ) , g ( h − 1 ) ( x ′ ) ⟩ \langle g^{(h-1)}(x),g^{(h-1)}(x') \rangle g(h1)(x),g(h1)(x)⟩ x x x x ′ x' x在第 h h h层的协方差。当宽度趋于无穷时, ⟨ g ( h − 1 ) ( x ) , g ( h − 1 ) ( x ′ ) ⟩ \langle g^{(h-1)}(x),g^{(h-1)}(x') \rangle g(h1)(x),g(h1)(x)⟩收敛至固定的数,这里表示为 Σ ( h − 1 ) ( x , x ′ ) \Sigma^{(h-1)}(x,x') Σ(h1)(x,x)。对于 h ∈ [ L ] h\in[L] h[L],该协方差的递归形式为
Σ ( 0 ) ( x , x ′ ) = x ⊤ x ′ Λ ( h ) ( x , x ′ ) = ( Σ ( h − 1 ) ( x , x ) Σ ( h − 1 ) ( x , x ′ ) Σ ( h − 1 ) ( x ′ , x ) Σ ( h − 1 ) ( x ′ , x ′ ) ) ∈ R 2 × 2 Σ ( h ) ( x , x ′ ) = c σ E ( u , v ) ∼ N ( 0 , Λ ( h ) ) [ σ ( u ) σ ( v ) ] (23) \begin{align} \Sigma^{(0)}(x,x')&=x^\top x' \\ \Lambda^{(h)}(x,x')&= \begin{pmatrix} \Sigma^{(h-1)}(x,x)&\Sigma^{(h-1)}(x,x') \\ \Sigma^{(h-1)}(x',x)&\Sigma^{(h-1)}(x',x') \\ \end{pmatrix}\in\mathbb{R}^{2\times 2} \\ \Sigma^{(h)}(x,x')&=c_\sigma E_{(u,v)\sim\mathcal{N}(0,\Lambda^{(h)})}[\sigma(u)\sigma(v)] \end{align}\tag{23} \\ Σ(0)(x,x)Λ(h)(x,x)Σ(h)(x,x)=xx=(Σ(h1)(x,x)Σ(h1)(x,x)Σ(h1)(x,x)Σ(h1)(x,x))R2×2=cσE(u,v)N(0,Λ(h))[σ(u)σ(v)](23)

你可能感兴趣的:(自然语言处理,深度学习,神经正切核,理论,NTK)