深入理解批标准化(Batch Normalization)

文章目录

  • 0 前言
  • 1 “Internal Covariate Shift”问题
    • 1.1 什么是“Internal Covariate Shift”
  • 2 Batch Norm的本质思想
    • 2.1 本质思想
    • 2.2 将激活输入调整为N(0,1)有何用?
    • 2.3 存在的一个问题
  • 3 Batch Norm的训练过程
    • 3.1 再Mini-batch SGD下做BN操作
  • 4 Batch Norm的预测过程
    • 4.1 预测中存在的问题
    • 4.2 解决该问题
      • 4.2.1 方法一
      • 4.2.2 方法二(常用)
  • 参考文献

0 前言

  • Batch Normalization作为最近一年来DL的重要成果,已经广泛被证明其有效性和重要性。虽然有些细节处理还解释不清其理论原因,但是实践证明好用才是真的好,别忘了DL从Hinton对深层网络做Pre-Train开始就是一个经验领先于理论分析的偏经验的一门学问。
  • 本文参考《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》.
  • I.I.D问题
    机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。
  • BatchNorm的作用是什么呢?
    BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。
  • 引入深度学习中常见的问题
    为什么深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢?很多论文都是解决这个问题的,比如ReLU激活函数,再比如Residual Network,BN本质上也是解释并从某个不同的角度来解决这个问题的。
  • BN用来解决梯度爆炸和梯度消失

1 “Internal Covariate Shift”问题

1.1 什么是“Internal Covariate Shift”

从论文名字可以看出,BN是用来解决“Internal Covariate Shift”问题的,那么首先得理解什么是“Internal Covariate Shift”?

论文首先说明Mini-Batch SGD相对于One Example SGD的两个优势:
梯度更新方向更准确;
并行计算速度快;
(为什么要说这些?因为BatchNorm是基于Mini-Batch SGD的,所以先夸下Mini-Batch SGD,当然也是大实话);然后吐槽下SGD训练的缺点:超参数调起来很麻烦。(作者隐含意思是用BN就能解决很多SGD的缺点)

  • 引入“Internal Covariate Shift”的概念(重点阅读)
    如果ML系统实例集合中的输入值X的分布老是变,这不符合IID假设,网络模型很难稳定的学规律,这不得引入迁移学习才能搞定吗,我们的ML系统还得去学习怎么迎合这种分布变化啊。对于深度学习这种包含很多隐层的网络结构,在训练过程中,因为各层参数不停在变化,所以每个隐层都会面临covariate shift的问题,也就是在训练过程中,隐层的输入分布老是变来变去,这就是所谓的“Internal Covariate Shift”,Internal指的是深层网络的隐层,是发生在网络内部的事情,而不是covariate shift问题只发生在输入层。

  • 举一个分布不一致的例子(BY Andrew Ng)
    这里输入的分布不一致(左边和右边)

  • 引入Batch Norm
    能不能让每个隐层节点的激活输入分布固定下来呢?这样就避免了“Internal Covariate Shift”问题了。

  • BN的启发来源
    启发来源的:之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话——所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布——那么神经网络会较快收敛,那么BN作者就开始推论了:图像是深度神经网络的输入层,做白化能加快收敛,那么其实对于深度网络来说,其中某个隐层的神经元是下一层的输入,意思是其实深度神经网络的每一个隐层都是输入层,不过是相对下一层来说而已,那么能不能对每个隐层都做白化呢?这就是启发BN产生的原初想法,而BN也确实就是这么做的,可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作。

2 Batch Norm的本质思想

2.1 本质思想

  • BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
  • 总结一下
    对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,也就是说收敛地快。BN说到底就是这么个机制,方法很简单,道理很深刻。

2.2 将激活输入调整为N(0,1)有何用?

深入理解批标准化(Batch Normalization)_第1张图片
这意味着在一个标准差范围内,也就是说64%的概率x其值落在[-1,1]的范围内,在两个标准差范围内,也就是说95%的概率x其值落在了[-2,2]的范围内。那么这又意味着什么?我们知道,激活值x=WU+B,U是真正的输入,x是某个神经元的激活值,假设非线性函数是sigmoid,那么看下sigmoid(x)其图形:
深入理解批标准化(Batch Normalization)_第2张图片
及sigmoid(x)的导数为:G’=f(x)*(1-f(x)),因为f(x)=sigmoid(x)在0到1之间,所以G’在0到0.25之间,其对应的图如下:
深入理解批标准化(Batch Normalization)_第3张图片

  • 假设没有经过BN调整前x的原先正态分布均值是-6,方差是1,那么意味着95%的值落在了[-8,-4]之间,那么对应的Sigmoid(x)函数的值明显接近于0,这是典型的梯度饱和区,在这个区域里梯度变化很慢,为什么是梯度饱和区?请看下sigmoid(x)如果取值接近0或者接近于1的时候对应导数函数取值,接近于0,意味着梯度变化很小甚至消失。而假设经过BN后,均值是0,方差是1,那么意味着95%的x值落在了[-2,2]区间内,很明显这一段是sigmoid(x)函数接近于线性变换的区域,意味着x的小变化会导致非线性函数值较大的变化,也即是梯度变化较大,对应导数函数图中明显大于0的区域,就是梯度非饱和区。
  • 从上面几个图应该看出来BN在干什么了吧?其实就是把隐层神经元激活输入x=WU+B从变化不拘一格的正态分布通过BN操作拉回到了均值为0,方差为1的正态分布,即原始正态分布中心左移或者右移到以0为均值,拉伸或者缩减形态形成以1为方差的图形。什么意思?就是说经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程

2.3 存在的一个问题

如果都通过BN,那么不就跟把非线性函数替换成线性函数效果相同了?这意味着什么?我们知道,如果是多层的线性函数变换其实这个深层是没有意义的,因为多层线性网络跟一层线性网络是等价的。这意味着网络的表达能力下降了,这也意味着深度的意义就没有了。所以BN为了保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),每个神经元增加了两个参数scaleshift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。核心思想应该是想找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。当然,这是我的理解,论文作者并未明确这样说。但是很明显这里的scale和shift操作是会有争议的,因为按照论文作者论文里写的理想状态,就会又通过scale和shift操作把变换后的x调整回未变换的状态,那不是饶了一圈又绕回去原始的“Internal Covariate Shift”问题里去了吗,感觉论文作者并未能够清楚地解释scale和shift操作的理论原因。

3 Batch Norm的训练过程

3.1 再Mini-batch SGD下做BN操作

  • BN操作在哪里进行?
    在未经过激活函数的输入和通过激活函数的中间放一个BN层。(这里一些论文有一些争议,就是有的人会把BN层放在激活函数之后使用,但是一般的,都是放在中间)
  • BN操作
    深入理解批标准化(Batch Normalization)_第4张图片
    对于Mini-Batch SGD来说,一次训练过程里面包含m个训练实例,其具体BN操作就是对于隐层内每个神经元的激活值来说,进行如下变换:
    深入理解批标准化(Batch Normalization)_第5张图片
    变换的意思是:某个神经元对应的原始的激活x通过减去mini-Batch内m个实例获得的m个激活x求得的均值E(x)并除以求得的方差Var(x)来进行转换。
  • 加入scale和shift超参数
    上文说过经过这个变换后某个神经元的激活x形成了均值为0,方差为1的正态分布,目的是把值往后续要进行的非线性变换的线性区拉动,增大导数值,增强反向传播信息流动性,加快训练收敛速度。但是这样会导致网络表达能力下降,为了防止这一点,每个神经元增加两个调节参数(scale和shift),这两个参数是通过训练来学习到的,用来对变换后的激活反变换,使得网络表达能力增强,即对变换后的激活进行如下的scale和shift操作,这其实是变换的反操作
    在这里插入图片描述
  • 具体算法
    深入理解批标准化(Batch Normalization)_第6张图片
  • 代码
    tf.keras.layers.BatchNormalization()

4 Batch Norm的预测过程

4.1 预测中存在的问题

  • 在预测时,当我们的输入只有一个实例时,应如何选取方差和均值?

4.2 解决该问题

4.2.1 方法一

利用全局统计量

  • 很简单,因为每次做Mini-Batch训练时,都会有那个Mini-Batch里m个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量,即:
    深入理解批标准化(Batch Normalization)_第7张图片
  • 有了均值和方差,每个隐层神经元也已经有对应训练好的Scaling参数和Shift参数,就可以在推导的时候对每个神经元的激活数据计算BN进行变换了,:
    深入理解批标准化(Batch Normalization)_第8张图片

4.2.2 方法二(常用)

指数加权平均(在实践中常用这个方法)

  • 训练时,利用不同batch中对应的量进行指数加权平均计算我们的统计量均值和方差。
  • 在深度学习框架里一般会有默认的方法,不需要调整。
    深入理解批标准化(Batch Normalization)_第9张图片

参考文献

参考文献1

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

你可能感兴趣的:(深度学习,优化算法)