其实关于BN层,我在之前的文章“梯度爆炸”那一篇中已经涉及到了,但是鉴于面试经历中多次问道这个,这里再做一个更加全面的讲解。
Batch Normalization的原论文作者给了Internal Covariate Shift一个较规范的定义:在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift。
这里做一个简单的数学定义,对于全链接网络而言,第i层的数学表达可以体现为:
Z i = W i × i n p u t i + b i Z^i=W^i\times input^i+b^i Zi=Wi×inputi+bi
i n p u t i + 1 = g i ( Z i ) input^{i+1}=g^i(Z^i) inputi+1=gi(Zi)
【怎么理解ICS问题】
我们知道,随着梯度下降的进行,每一层的参数 W i , b i W^i,b^i Wi,bi都会不断地更新,这意味着 Z i Z^i Zi的分布也不断地改变,从而 i n p u t i + 1 input^{i+1} inputi+1的分布发生了改变。这意味着,除了第一层的输入数据不改变,之后所有层的输入数据的分布都会随着模型参数的更新发生改变,而每一层就要不停的去适应这种数据分布的变化,这个过程就是Internal Covariate Shift。
【ICS带来的收敛速度慢】
因为每一层的参数不断发生变化,从而每一层的计算结果的分布发生变化,后层网络不断地适应这种分布变化,这个时候会让整个网络的学习速度过慢。
【梯度饱和问题】
因为神经网络中经常会采用sigmoid,tanh这样的饱和激活函数(saturated actication function),因此模型训练有陷入梯度饱和区的风险。解决这样的梯度饱和问题有两个思路:第一种就是更为非饱和性激活函数,例如线性整流函数ReLU可以在一定程度上解决训练进入梯度饱和区的问题。另一种思路是,我们可以让激活函数的输入分布保持在一个稳定状态来尽可能避免它们陷入梯度饱和区,这也就是Normalization的思路。
batchNormalization就像是名字一样,对一个batch的数据进行normalization。
现在假设一个batch有3个数据,每个数据有两个特征:(1,2),(2,3),(0,1)
如果做一个简单的normalization,那么就是计算均值和方差,把数据减去均值除以标准差,变成0均值1方差的标准形式。
对于第一个特征来说:
μ = 1 3 ( 1 + 2 + 0 ) = 1 \mu=\frac{1}{3}(1+2+0)=1 μ=31(1+2+0)=1
σ 2 = 1 3 ( ( 1 − 1 ) 2 + ( 2 − 1 ) 2 + ( 0 − 1 ) 2 ) = 0.67 \sigma^2=\frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67 σ2=31((1−1)2+(2−1)2+(0−1)2)=0.67
【通用公式】
μ = 1 m ∑ i = 1 m Z \mu=\frac{1}{m}\sum_{i=1}^m{Z} μ=m1∑i=1mZ
σ 2 = 1 m ∑ i = 1 m ( Z − μ ) \sigma^2=\frac{1}{m}\sum_{i=1}^m(Z-\mu) σ2=m1∑i=1m(Z−μ)
Z ^ = Z − μ σ 2 + ϵ \hat{Z}=\frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}} Z^=σ2+ϵZ−μ
目前为止,我们做到了让每个特征的分布均值为0,方差为1。这样分布都一样,一定不会有ICS问题
如同上面提到的,Normalization操作我们虽然缓解了ICS问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。每一层的分布都相同,所有任务的数据分布都相同,模型学啥呢
【0均值1方差数据的弊端】
为了解决这个问题,BN层引入了两个可学习的参数 γ \gamma γ和 β \beta β,这样,经过BN层normalization的数据其实是服从 β \beta β均值, γ 2 \gamma^2 γ2方差的数据。
所以对于某一层的网络来说,我们现在变成这样的流程:
(上面公式中,省略了 i i i,总的来说是表示第i层的网络层产生第i+1层输入数据的过程)
我们知道BN在每一层计算的 μ \mu μ与 σ 2 \sigma^2 σ2 都是基于当前batch中的训练数据,但是这就带来了一个问题:我们在预测阶段,有可能只需要预测一个样本或很少的样本,没有像训练样本中那么多的数据,这样的 σ 2 \sigma^2 σ2和 μ \mu μ要怎么计算呢?
利用训练集训练好模型之后,其实每一层的BN层都保留下了每一个batch算出来的 μ \mu μ和 σ 2 \sigma^2 σ2.然后呢利用整体的训练集来估计测试集的 μ t e s t \mu_{test} μtest和 σ t e s t 2 \sigma_{test}^2 σtest2
μ t e s t = E ( μ t r a i n ) \mu_{test}=E(\mu_{train}) μtest=E(μtrain)
σ t e s t 2 = m m − 1 E ( σ t r a i n 2 ) \sigma_{test}^2=\frac{m}{m-1}E(\sigma_{train}^2) σtest2=m−1mE(σtrain2)
然后再对测试机进行BN层:
当然,计算训练集的 μ \mu μ和KaTeX parse error: Undefined control sequence: \simga at position 1: \̲s̲i̲m̲g̲a̲的方法除了上面的求均值之外。吴恩达老师在其课程中也提出了,可以使用指数加权平均的方法。不过都是同样的道理,根据整个训练集来估计测试机的均值方差。
BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度。
BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。
BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题
通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习 γ \gamma γ与 β \beta β 又让数据保留更多的原始信息。
BN具有一定的正则化效果
在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音
【weight normalization】
Weight Normalization是对网络权值进行normalization,也就是L2 norm。
相对于BN有下面的优势:
但是WN要特别注意参数初始化的选择。
【Layer normalization】
更常见的比较是BN与LN的比较。
BN层有两个缺点:
但是,在CNN中LN并没有取得比BN更好的效果。
参考链接: