Batch Normalization论文解读

BN原理分析

    • 前言
      • 为什么会产生梯度消失和梯度爆炸?
      • 梯度消失
      • 梯度爆炸
    • 提出背景
      • 什么是Internal Covariate Shift
      • Internal Covariate Shift会带来什么问题
      • 如何缓解Internal Covariate Shift
    • Batch Normalization 整体思路
      • 思路
      • 算法
    • inference阶段BN的使用
    • BN的优势

BN这篇论文,我看了很久,是时候来归纳总结一下论文的整体思路与原理,顺便巩固与加深对BN的理解~

文章主要会从四个方面对Batch Normalization进行详解:

  • 提出背景
  • BN的整体思路
  • inference阶段BN的使用
  • BN的优势

文章主要参考了论文以及其他博主的资料。所有参考链接均见文章参考链接部分。

前言

在开始论文的讲解之前,先抛出两个概念梯度消失梯度爆炸

为什么会产生梯度消失和梯度爆炸?

目前优化神经网络都是基于反向传播,即根据损失函数的计算通过梯度反向传播的方式,使得网络权值更新优化,而这个反向传播的过程用到了链式法则反向传播算法可以说是梯度下降在链式法则中的应用。

链式法则是一个连乘式,当层数越深时,梯度也会以指数形式传播。在跟新网络权值时,得到的梯度接近0特别大,也就是梯度消失爆炸。梯度消失与梯度爆炸其实在本质上是一样的。

梯度消失

深层网络或是选择了不合适的激活函数,例如sigmoid。梯度消失发生时,接近输出层的隐藏层权值更新相对正常,但是当越靠近输入层时,会导致靠近输入层的隐藏层权值更新缓慢或更新停滞。这相当于训练时,只等价于后面几层的浅层网络的学习。

梯度爆炸

一般出现在深层网络权值初始值太大的情况下,在深层神经网络或循环神经网络中,误差的梯度可在更新中累计相乘。如果网络层之间的梯度值大于1,那么重复相乘会导致梯度呈指数级增长,梯度变的非常大,然后导致网络权重的大幅更新,并因此使网络变得不稳定

梯度爆炸会伴随一些细微的信号,如:

  1. 模型不稳定,导致更新过程中的损失出现显著变化
  2. 训练过程中权值变得非常大,以至于溢出,导致模型损失变成NaN

提出背景

首先解读论文标题《Batch Normalization: Acclerating Deep Network Training by Reducing Internal Covariate Shift》:通过减少内部协变量偏移来加速深度网络的训练。所以首先需要搞清楚内部协变量偏移是什么!

什么是Internal Covariate Shift

原论文作者在文中给的定义:在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift。

其实通俗理解就是:对于深度学习这种包含很多隐层的网络结构,在训练过程中,由于每一层的参数会被更新,而每一层的输出都会作为下一层的输入,意味着下一层要去不停适应这种数据分布的变化,这一过程就叫做Internal Covariate Shift。

Internal Covariate Shift会带来什么问题

(1)上层网络需要不停调整来适应输入数据分布的变化,导致网络学习速度的降低

由于后层网络需要不停去适应输入数据分布的变化,会使得整个网络的学习速率过慢。

(2)网络的训练过程容易陷入梯度饱和区,减缓网络收敛速度

当我们使用饱和激活函数(saturated activation function)时,例如sigmoid,tanh激活函数,很容易使得模型训练陷入梯度饱和区(saturated regime)。随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值,即函数最平缓的部位),所以这导致反向传播时低层神经网络的梯度消失。

对于激活函数梯度饱和问题,有两种解决思路,第一种是用非饱和激活函数,例如线性整流函数(ReLU)可以在一定程度上解决训练进入梯度饱和区的问题。另一种思路是,我们可以让激活函数的输入分布保持在一个稳定状态来尽可能避免它们陷入梯度饱和区,这也是Normalization的思路。

如何缓解Internal Covariate Shift

要缓解这个问题,其实就是缓解每一层输入值分布的不停变化,因此可以固定每一层网络输入值的分布来缓解ICS问题。

(1)白化(Whitening)

白化是BN的启发来源——所谓白化就是对输入数据分布变换到0均值,单位方差的正态分布。在Lecun 98年发表的论文里验证了如果在图像处理中对输入图像进行白化(Whiten)操作的话,神经网络会较快收敛。因此BN的作者也作了类似的推论:图像是深度神经网络的输入层,做白化能加快收敛,那么对于深度网络来说,任何隐层的输出都是作为下一层的输入,即每一个隐层都可以看作是输入层,不过是相对下一层来说而已,那么能不能对每个隐层都做白化呢?于是作者提出了BN的思想。

(2)Batch Normalization的提出

既然白化可以解决这个问题,为什么还要提出别的解决办法,因为白化存在着以下两个问题:

  • 白化过程计算成本太高:1. 白化需要拿整个训练集进行归一化 2.白化需要计算协方差矩阵 C o v [ x ] = E [ x x T ] − E [ x ] E [ x ] T Cov[x]=E[xx^T]-E[x]E[x]^T Cov[x]=E[xxT]E[x]E[x]T还要计算 C o v [ x ] − 1 2 ( x − E [ x ] ) Cov[x]^{-\frac{1}{2}}(x-E[x]) Cov[x]21(xE[x])
  • 白化过程改变了网络每一层的分布,因而改变了网络层中本身数据的表达能力。底层网络学习到的参数信息会被白化操作丢弃掉。

于是作者主要是解决白化的以上两个问题,提出了一个简化版的Batch Normalization。一方面,BN方面要能够简化计算过程;另一方面又需要经过规范化处理后让数据尽可能保留原始的表达能力。


Batch Normalization 整体思路

思路

BN算法一共做了两个简化:

  1. 单独对每个特征进行Normalization,让每个特征都有均值为0,方差为1的分布就可以。
  2. 使用随机梯度下降时,拿整个训练集进行归一化不实际,于是提出了使用mini-batch的方式,对每个mini-batch计算对应的均值与方差。

算法

由于白化操作减弱了网络中每一层输入数据的表达能力,作者又加了一个线性变换操作,让这些数据能够尽快恢复本身的表达能力。

参数定义

  • l : l: l: 网络中的层标号
  • L : L: L: 网络中的最后一层或总层数
  • d l : d_l: dl: l l l 层的维度,即神经元结点数
  • W [ l ] : W^{[l]}: W[l]: l l l 层的权重矩阵, W [ l ] ∈ R d l × d l − 1 W^{[l]}\in R^{d_l \times d_{l-1}} W[l]Rdl×dl1
  • b [ l ] : b^{[l]}: b[l]: l l l 层的偏置向量, b [ l ] ∈ R d l × 1 b^{[l]}\in R^{d_l \times 1} b[l]Rdl×1
  • Z [ l ] : Z^{[l]}: Z[l]: l l l 层的线性计算结果, Z [ l ] = W [ l ] × i n p u t + b [ l ] Z^{[l]}= W^{[l]} \times input +b^{[l]} Z[l]=W[l]×input+b[l]
  • g [ l ] ( ⋅ ) : g^{[l]}(\cdot): g[l](): l l l 层的激活函数
  • A [ l ] : A^{[l]}: A[l]: l l l 层的非线性激活结果, A [ l ] = g [ l ] ( Z [ l ] ) A^{[l]}=g^{[l]}(Z^{[l]}) A[l]=g[l](Z[l])

样本定义

  • M : M: M: 训练样本的数量
  • N : N: N: 训练样本的特征数
  • X : X: X: 训练样本集, X = { x ( 1 ) , x ( 2 ) , . . . , x ( M ) } X=\{x^{(1)},x^{(2)},...,x^{(M)}\} X={x(1),x(2),...,x(M)}, X ∈ R N × M X\in R^{N\times M} XRN×M(这里 X X X的每一列是一个样本)
  • m : m: m: batch size,即每个batch中样本的数量
  • χ ( i ) : \chi^{(i)}: χ(i): i i i 个mini-batch的训练数据, X = { χ ( 1 ) , χ ( 2 ) , . . . , χ ( k ) } X=\{\chi^{(1)},\chi^{(2)},...,\chi^{(k)}\} X={χ(1),χ(2),...,χ(k)},其中 χ ( i ) ∈ R N × m \chi^{(i)}\in R^{N\times m} χ(i)RN×m

算法流程
对每个特征进行独立的normalization,并且考虑一个batch的训练,传入m个训练样本,并关注网络中的某一层,忽略上标 l l l.

关注当前层的第 j j j 个维度,也就是第 j j j 个神经元结点,对当前维度进行规范化:

μ j = 1 m ∑ i = 1 m Z j ( i ) \mu_j =\frac{1}{m}\sum_{i=1} ^{m}Z_j^{(i)} μj=m1i=1mZj(i)

σ j 2 = 1 m ∑ i = 1 m ( Z j ( i ) − μ j ) 2 \sigma_j^2=\frac{1}{m}\sum_{i=1}^{m}(Z_j^{(i)}-\mu_j)^2 σj2=m1i=1m(Zj(i)μj)2

Z j ^ = Z j − μ j σ j 2 + ϵ \hat{Z_j}=\frac{Z_j-\mu_j}{\sqrt[]{\sigma_j^2 +\epsilon}} Zj^=σj2+ϵ Zjμj

ϵ \epsilon ϵ 是为了防止方差为0产生无效计算

结合具体的例子来看,下图均来自:https://zhuanlan.zhihu.com/p/34879333
下图我们只关注第 l l l 层的计算结果,左边的矩阵是 Z [ l ] = W [ l ] A [ l − 1 ] + b [ l ] Z^{[l]}=W^{[l]}A^{[l-1]}+b^{[l]} Z[l]=W[l]A[l1]+b[l] 线性计算结果,还未进行激活函数的非线性变换。此时每一列是一个样本,图中一共8列,代表当前训练的batch size 为8,每一行代表当前 l l l 层神经元的一个结点,可以看到当前 l l l 层共有4个神经元结点,即第 l l l 层维度为4,每行的数据分布均不同:
Batch Normalization论文解读_第1张图片
对于第一个神经元,求得 μ 1 = 1.65 , σ 1 2 = 0.44 \mu_1 =1.65,\sigma_1^2=0.44 μ1=1.65,σ12=0.44 其中( ϵ = 1 0 − 8 \epsilon=10^{-8} ϵ=108),此时我们用 μ 1 , σ 1 2 \mu_1,\sigma_1^2 μ1,σ12 对第一行数据(第一个维度)进行normalization得到新的值 [ − 0.98 , − 0.23 , − 0.68 , − 1.13 , 0.08 , 0.68 , 2.19 , 0.08 ] [-0.98,-0.23,-0.68,-1.13,0.08,0.68,2.19,0.08] [0.98,0.23,0.68,1.13,0.08,0.68,2.19,0.08]。随后我们可以计算出其他输入维度归一化后的值,如下图:
Batch Normalization论文解读_第2张图片
通过上面的变换,使得第 l l l 层的输入每个特征的分布均值为0,方差为1。

到了这一步之后,每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。这也很好理解,通过调整数据分布,会使得大部分的值落入函数的线性部分,但是我们都知道,多层的线性函数变换其实是没有意义的,因为多层线性网络跟一层线性网络是等价的。因此BN又引入了两个可学习的参数 γ \gamma γ β \beta β 。这两个参数的引入是为了恢复数据本身的表达能力,对规范化后的数据进行线性变换,即 Z j ~ = γ j Z j ^ + β j \tilde{Z_j}=\gamma_j\hat{Z_j}+\beta_j Zj~=γjZj^+βj。当 γ 2 = σ 2 , β = μ \gamma^2=\sigma^2,\beta=\mu γ2=σ2,β=μ时,可以实现等价变换(idetity transform)并且保留了原始输入特征的分布信息。

公式
对于神经网络的第 l l l 层:

Z [ l ] = W [ l ] A [ l − 1 ] + b [ l ] Z^{[l]}=W^{[l]}A^{[l-1]}+b^{[l]} Z[l]=W[l]A[l1]+b[l]

μ = 1 m ∑ i = 1 m Z [ l ] ( i ) \mu =\frac{1}{m}\sum_{i=1}^mZ^{[l](i)} μ=m1i=1mZ[l](i)

σ 2 = 1 m ∑ i = 1 m ( Z [ l ] ( i ) − μ ) 2 \sigma^2=\frac{1}{m}\sum_{i=1}^m(Z^{[l](i)}-\mu)^2 σ2=m1i=1m(Z[l](i)μ)2

Z ~ [ l ] = γ ⋅ Z [ l ] − μ σ 2 + ϵ + β \tilde{Z}^{[l]}=\gamma \cdot \frac{Z^{[l]}-\mu}{\sqrt[]{\sigma^2+\epsilon}}+\beta Z~[l]=γσ2+ϵ Z[l]μ+β

A [ l ] = g [ l ] ( Z ~ [ l ] ) A^{[l]}=g^{[l]}(\tilde{Z}^{[l]}) A[l]=g[l](Z~[l])


inference阶段BN的使用

在推理(inference)阶段,输入有时就只有一个样本或很少的样本,也就是此时很有可能batch_size = 1,无法求出均值和方差,如何获取 μ \mu μ σ 2 \sigma^2 σ2呢?

利用BN训练好模型后,保留每组mini-batch训练数据在网络中每一层的 μ b a t c h \mu_{batch} μbatch σ b a t c h 2 \sigma_{batch}^2 σbatch2。此时我们使用整个样本的统计量来对Test数据进行归一化,具体来说是使用均值与方差的无偏估计:

μ t e s t = E ( μ b a t c h ) \mu_{test}=E(\mu_{batch}) μtest=E(μbatch)

σ t e s t 2 = m m − 1 E ( σ b a t c h 2 ) \sigma_{test}^2=\frac{m}{m-1}E(\sigma_{batch}^2) σtest2=m1mE(σbatch2)

得到每个特征的均值与方差的无偏估计后,我们对test数据采用同样的normalization方法:

B N ( X t e s t ) = γ ⋅ X t e s t − μ t e s t σ t e s t 2 + ϵ + β BN(X_{test})=\gamma\cdot\frac{X_{test}-\mu_{test}}{\sqrt[]{\sigma_{test}^2}+\epsilon}+\beta BN(Xtest)=γσtest2 +ϵXtestμtest+β


BN的优势

(1)可以选择比较大的初始学习率,对网络中的参数不那么敏感,简化调参过程

在网络训练时,通常会谨慎地采用一些权重初始化方法或者合适的学习率来保证网络稳定训练。

当学习率设置太高时,会使得参数更新步伐过大,容易出现震荡和不收敛。而且初始权值过大容易导致梯度消失,然而使用BN网络将不会受到参数数值大小的影响。例如,对参数 W W W 进行缩放的得到 a W aW aW ,对于缩放前的值 W μ W\mu Wμ,我们设其均值为 μ 1 \mu_1 μ1,方差为 σ 1 2 \sigma_1^2 σ12;对于缩放值 a W μ aW\mu aWμ,设其均值为 μ 2 \mu_2 μ2,方差为 σ 2 2 \sigma_2^2 σ22 ,则我们有:

μ 2 = a μ 1 \mu_2 =a\mu_1 μ2=aμ1 , σ 2 2 = a 2 σ 1 2 \sigma_2^2=a^2\sigma_1^2 σ22=a2σ12

此时忽略 ϵ \epsilon ϵ ,则有:

B N ( a W μ ) = γ ⋅ a W μ − μ 2 σ 2 2 + β = γ ⋅ a W μ − a μ 1 a 2 σ 1 2 + β = γ ⋅ W μ − μ 1 σ 1 2 + β = B N ( W μ ) BN(aW\mu)=\gamma\cdot\frac{aW\mu-\mu_2}{\sqrt[]{\sigma_2^2}}+\beta=\gamma\cdot\frac{aW\mu-a\mu_1}{\sqrt[]{a^2\sigma_1^2}}+\beta=\gamma\cdot\frac{W\mu-\mu_1}{\sqrt[]{\sigma_1^2}}+\beta=BN(W\mu) BN(aWμ)=γσ22 aWμμ2+β=γa2σ12 aWμaμ1+β=γσ12 Wμμ1+β=BN(Wμ)

∂ B N ( ( a W ) μ ) ∂ μ = γ ⋅ a W σ 2 2 = γ ⋅ a W a 2 σ 1 2 = ∂ B N ( W μ ) ∂ μ \frac{\partial{BN((aW)\mu)}}{\partial\mu}=\gamma\cdot\frac{aW}{\sqrt[]{\sigma_2^2}}=\gamma\cdot\frac{aW}{\sqrt[]{a^2\sigma_1^2}}=\frac{\partial{BN(W\mu)}}{\partial\mu} μBN((aW)μ)=γσ22 aW=γa2σ12 aW=μBN(Wμ)

∂ B N ( ( a W ) μ ) ∂ a W = γ ⋅ μ σ 2 2 = γ ⋅ a W a ⋅ σ 1 2 = 1 2 ⋅ ∂ B N ( W μ ) ∂ W \frac{\partial{BN((aW)\mu)}}{\partial{aW}}=\gamma\cdot\frac{\mu}{\sqrt[]{\sigma_2^2}}=\gamma\cdot\frac{aW}{a\cdot\sqrt[]{\sigma_1^2}}=\frac{1}{2}\cdot\frac{\partial{BN(W\mu)}}{\partial{W}} aWBN((aW)μ)=γσ22 μ=γaσ12 aW=21WBN(Wμ)

可以看到,经过BN操作以后,权重的缩放并不会影响到对 μ \mu μ 的梯度计算;并且当权重越大时,即 a a a 越大, 1 a \frac{1}{a} a1 越小,意味着权重 W W W 的梯度反而越小,这样BN就保证了梯度不会依赖于参数的scale,使得参数的更新处在更加稳定的状态.

(2)起到了一定的正则效果

作者证实了网络加入BN之后,可以丢弃dropout,并且模型一样具有很好的泛化效果.

(3)BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题

通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习 γ \gamma γ β \beta β 又让数据保留更多的原始信息。

(4)BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度

BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。


最后:

B N ( W μ + b ) = B N ( W μ ) BN(W\mu+b)=BN(W\mu) BN(Wμ+b)=BN(Wμ),因此b可以被忽略掉或可以被置为0(在分子减去均值的时候会被减去).


参考文献与博主链接:

[1] Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456. https://arxiv.org/pdf/1502.03167.pdf
[2] 知乎博主专栏:https://zhuanlan.zhihu.com/p/34879333
[3] 梯度消失与梯度爆炸:https://zhuanlan.zhihu.com/p/72589432

你可能感兴趣的:(深度学习,深度学习)