论文提出NF-ResNet,根据网络的实际信号传递进行分析,模拟BatchNorm在均值和方差传递上的表现,进而代替BatchNorm。论文实验和分析十分足,出来的效果也很不错。一些初始化方法的理论效果是对的,但实际使用会有偏差,论文通过实践分析发现了这一点进行补充,贯彻了实践出真知的道理
来源:晓飞的算法工程笔记 公众号
论文: Characterizing signal propagation to close the performance gap in unnormalized ResNets
BatchNorm是深度学习中核心计算组件,大部分的SOTA图像模型都使用它,主要有以下几个优点:
然而,尽管BatchNorm很好,但还是有以下缺点:
目前,很多研究开始寻找替代BatchNorm的归一化层,但这些替代层要么表现不行,要么会带来新的问题,比如增加推理的计算消耗。而另外一些研究则尝试去掉归一化层,比如初始化残差分支的权值,使其输出为零,保证训练初期大部分的信息通过skip path进行传递。虽然能够训练很深的网络,但使用简单的初始化方法的网络的准确率较差,而且这样的初始化很难用于更复杂的网络中。
因此,论文希望找出一种有效地训练不含BatchNorm的深度残差网络的方法,而且测试集性能能够媲美当前的SOTA,论文主要贡献如下:
许多研究从理论上分析ResNet的信号传播,却很少会在设计或魔改网络的时候实地验证不同层数的特征缩放情况。实际上,用任意输入进行前向推理,然后记录网络不同位置特征的统计信息,可以很直观地了解信息传播状况并尽快发现隐藏的问题,不用经历漫长的失败训练。于是,论文提出了信号传播图(Signal Propagation Plots,SPPs),输入随机高斯输入或真实训练样本,然后分别统计每个残差block输出的以下信息:
论文对常见的BN-ReLU-Conv结构和不常见的ReLU-BN-Conv结构进行了实验统计,实验的网络为600层ResNet,采用He初始化,定义residual block为 x l + 1 = f l ( x l ) + x l x_{l+1}=f_{l}(x_{l}) + x_{l} xl+1=fl(xl)+xl,从SPPs可以发现了以下现象:
假如直接去掉BatchNorm,Average Squared Channel Means和Average Channel Variance将会不断地增加,这也是深层网络难以训练的原因。所以要去掉BatchNorm,必须设法模拟BatchNorm的信号传递效果。
根据前面的SPPs,论文设计了新的redsidual block x l + 1 = x l + α f l ( x l / β l ) x_{l+1}=x_l+\alpha f_l(x_l/\beta_l) xl+1=xl+αfl(xl/βl),主要模拟BatchNorm在均值和方差上的表现,具体如下:
根据上面的设计,给定 V a r ( x 0 ) = 1 Var(x_0)=1 Var(x0)=1和 β l = V a r ( x l ) \beta_l=\sqrt{Var(x_l)} βl=Var(xl),可根据 V a r ( x l ) = V a r ( x l − 1 ) + α 2 Var(x_l)=Var(x_{l-1})+\alpha^2 Var(xl)=Var(xl−1)+α2直接计算第 l l l个residual block的输出的方差。为了模拟ResNet中的累积方差在transition block处被重置,需要将transition block的skip path的输入缩小为 x l / β l x_l/\beta_l xl/βl,保证每个stage开头的transition block输出方差满足 V a r ( x l + 1 ) = 1 + α 2 Var(x_{l+1})=1+\alpha^2 Var(xl+1)=1+α2。将上述简单缩放策略应用到残差网络并去掉BatchNorm层,就得到了Normalizer-Free ResNets(NF-ResNets)。
论文对使用He初始化的NF-ResNet进行SPPs分析,结果如图2,发现了两个比较意外的现象:
为了验证上述现象,论文将网络的ReLU去掉再进行SPPs分析。如图7所示,当去掉ReLU后,Average Channel Squared Mean接近于0,而且残差分支输出的接近1,这表明是ReLU导致了mean-shift现象。
论文也从理论的角度分析了这一现象,首先定义转化 z = W g ( x ) z=Wg(x) z=Wg(x), W W W为任意且固定的矩阵, g ( ⋅ ) g(\cdot) g(⋅)为作用于独立同分布输入 x x x上的elememt-wise激活函数,所以 g ( x ) g(x) g(x)也是独立同分布的。假设每个维度 i i i都有 E ( g ( x i ) ) = μ g \mathbb{E}(g(x_i))=\mu_g E(g(xi))=μg以及 V a r ( g ( x i ) ) = σ g 2 Var(g(x_i))=\sigma^2_g Var(g(xi))=σg2,则输出 z i = ∑ j N W i , j g ( x j ) z_i=\sum^N_jW_{i,j}g(x_j) zi=∑jNWi,jg(xj)的均值和方差为:
其中, μ w i , . \mu w_{i,.} μwi,.和 σ w i , . \sigma w_{i,.} σwi,.为 W W W的 i i i行(fan-in)的均值和方差:
当 g ( ⋅ ) g(\cdot) g(⋅)为ReLU激活函数时,则 g ( x ) ≥ 0 g(x)\ge 0 g(x)≥0,意味着后续的线性层的输入都为正均值。如果 x i ∼ N ( 0 , 1 ) x_i\sim\mathcal{N}(0,1) xi∼N(0,1),则 μ g = 1 / 2 π \mu_g=1/\sqrt{2\pi} μg=1/2π。由于 μ g > 0 \mu_g>0 μg>0,如果 μ w i \mu w_i μwi也是非零,则 z i z_i zi同样有非零均值。需要注意的是,即使 W W W从均值为零的分布中采样而来,其实际的矩阵均值肯定不会为零,所以残差分支的任意维度的输出也不会为零,随着网络深度的增加,越来越难训练。
为了消除mean-shift现象以及保证残差分支 f l ( ⋅ ) f_l(\cdot) fl(⋅)具有方差不变的特性,论文借鉴了Weight Standardization和Centered Weight Standardization,提出Scaled Weight Standardization(Scaled WS)方法,该方法对卷积层的权值重新进行如下的初始化:
μ \mu μ和 σ \sigma σ为卷积核的fan-in的均值和方差,权值 W W W初始为高斯权值, γ \gamma γ为固定常量。代入公式1可以得出,对于 z = W ^ g ( x ) z=\hat{W}g(x) z=W^g(x),有 E ( z i ) = 0 \mathbb{E}(z_i)=0 E(zi)=0,去除了mean-shift现象。另外,方差变为 V a r ( z i ) = γ 2 σ g 2 Var(z_i)=\gamma^2\sigma^2_g Var(zi)=γ2σg2, γ \gamma γ值由使用的激活函数决定,可保持方差不变。
Scaled WS训练时增加的开销很少,而且与batch数据无关,在推理的时候更是无额外开销的。另外,训练和测试时的计算逻辑保持一致,对分布式训练也很友好。从图2的SPPs曲线可以看出,加入Scaled WS的NF-ResNet-600的表现跟ReLU-BN-Conv十分相似。
最后的因素是 γ \gamma γ值的确定,保证残差分支输出的方差在初始阶段接近1。 γ \gamma γ值由网络使用的非线性激活类型决定,假设非线性的输入 x ∼ N ( 0 , 1 ) x\sim\mathcal{N}(0,1) x∼N(0,1),则ReLU输出 g ( x ) = m a x ( x , 0 ) g(x)=max(x,0) g(x)=max(x,0)相当于从方差为 σ g 2 = ( 1 / 2 ) ( 1 − ( 1 / π ) ) \sigma^2_g=(1/2)(1-(1/\pi)) σg2=(1/2)(1−(1/π))的高斯分布采样而来。由于 V a r ( W ^ g ( x ) ) = γ 2 σ g 2 Var(\hat{W}g(x))=\gamma^2\sigma^2_g Var(W^g(x))=γ2σg2,可设置 γ = 1 / σ g = 2 1 − 1 π \gamma=1/\sigma_g=\frac{\sqrt{2}}{\sqrt{1-\frac{1}{\pi}}} γ=1/σg=1−π12来保证 V a r ( W ^ g ( x ) ) = 1 Var(\hat{W}g(x))=1 Var(W^g(x))=1。虽然真实的输入不是完全符合 x ∼ N ( 0 , 1 ) x\sim \mathcal{N}(0,1) x∼N(0,1),在实践中上述的 γ \gamma γ设定依然有不错的表现。
对于其他复杂的非线性激活,如SiLU和Swish,公式推导会涉及复杂的积分,甚至推出不出来。在这种情况下,可使用数值近似的方法。先从高斯分布中采样多个 N N N维向量 x x x,计算每个向量的激活输出的实际方差 V a r ( g ( x ) ) Var(g(x)) Var(g(x)),再取实际方差均值的平方根即可。
本文的核心在于保持正确的信息传递,所以许多常见的网络结构都要进行修改。如同选择 γ \gamma γ值一样,可通过分析或实践判断必要的修改。比如SE模块 y = s i g m o i d ( M L P ( p o o l ( h ) ) ) ∗ h y=sigmoid(MLP(pool(h)))*h y=sigmoid(MLP(pool(h)))∗h,输出需要与 [ 0 , 1 ] [0,1] [0,1]的权值进行相乘,导致信息传递减弱,网络变得不稳定。使用上面提到的数值近似进行单独分析,发现期望方差为0.5,这意味着输出需要乘以2来恢复正确的信息传递。
实际上,有时相对简单的网络结构修改就可以保持很好的信息传递,而有时候即便网络结构不修改,网络本身也能够对网络结构导致的信息衰减有很好的鲁棒性。因此,论文也尝试在维持稳定训练的前提下,测试Scaled WS层的约束的最大放松程度。比如,为Scaled WS层恢复一些卷积的表达能力,加入可学习的缩放因子和偏置,分别用于权值相乘和非线性输出相加。当这些可学习参数没有任何约束时,训练的稳定性没有受到很大的影响,反而对大于150层的网络训练有一定的帮助。所以,NF-ResNet直接放松了约束,加入两个可学习参数。
论文的附录有详细的网络实现细节,有兴趣的可以去看看。
总结一下,Normalizer-Free ResNet的核心有以下几点:
对比RegNet的Normalizer-Free变种与其他方法的对比,相对于EfficientNet还是差点,但已经十分接近了。
论文提出NF-ResNet,根据网络的实际信号传递进行分析,模拟BatchNorm在均值和方差传递上的表现,进而代替BatchNorm。论文实验和分析十分足,出来的效果也很不错。一些初始化方法的理论效果是对的,但实际使用会有偏差,论文通过实践分析发现了这一点进行补充,贯彻了实践出真知的道理。
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】