Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。

Batch Normalization的提出是为了解决随着网络深度加深,训练起来越困难,收敛越来越慢的问题。

为什么深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢?这是个在DL领域很接近本质的好问题。很多论文都是解决这个问题的,比如ReLU激活函数,再比如Residual Network。

  • 1、Internal Covariate Shift

BatchNorm是基于Mini-Batch SGD的,Mini-Batch SGD相对于One Example SGD的两个优势:梯度更新方向更准确;并行计算速度快。使用SGD训练有超参数调参的困难。

什么叫covariate shift?如果ML系统实例集合中的输入值X的分布老是变,不符合IID假设,ML系统还得去学习怎么迎合这种分布变化。对于深度学习这种包含很多隐藏层的网络结构,在训练过程中,因为各层参数老在变,所以每个隐藏层都会面临covariate shift的问题,也就是在训练过程中,隐藏层的输入分布老是变,Internal指的是深层网络的隐藏层,是发生在网络内部的事情,而不是covariate shift问题只发生在输入层。

引出BatchNorm的基本思想:让让每个隐藏层节点的激活函数的输入分布固定下来,避免了Internal Covariate Shift。

BN的启发来源:之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话,即对输入数据分布变换到0均值,单位方差的正态分布,神经网络会较快收敛。图像是深度神经网络的输入层,做白化能加快收敛,那么其实对于深度网络来说,其中某个隐藏层的神经元是下一层的输入,相对下一层来说也作为输入层,那么能不能对每个隐藏层做白化呢?这就是启发BN产生的原初想法,而BN也确实就是这么做的,可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作。

  • 2、BatchNorm的本质思想

随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值是两端大的负值或正值),所以这导致后向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第1张图片

假设某个隐层神经元原先的激活输入x取值符合正态分布,正态分布均值是-2,方差是0.5,对应上图中最左端的浅蓝色曲线,通过BN后转换为均值为0,方差是1的正态分布(对应上图中的深蓝色图形),输入x的取值正态分布整体右移2(均值的变化),图形曲线更平缓了(方差增大的变化)。这个图的意思是,BN其实就是把每个隐层神经元的激活输入分布从偏离均值为0方差为1的正态分布通过平移均值压缩或者扩大曲线尖锐程度,调整为均值为0方差为1的正态分布。那么把激活输入x调整到这个正态分布有什么用?
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第2张图片

标准正态分布64%的概率x其值落在[-1,1]的范围内,95%的概率x其值落在了[-2,2]的范围内。假设非线性激活函数是sigmoid, sigmoid(x)其图形如下:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第3张图片

sigmoid(x)的导数为:G’=f(x)*(1-f(x)),因为f(x)=sigmoid(x)在0到1之间,所以G’在0到0.25之间,其对应的图如下:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第4张图片

假设没有经过BN调整前x的原先正态分布均值是-6,方差是1,那么意味着95%的值落在了[-8,-4]之间,那么对应的Sigmoid(x)函数的值明显接近于0,对于导数取值为梯度饱和区,意味着梯度变化很小甚至消失。而假设经过BN后,均值是0方差是1,那么意味着95%的x值落在了[-2,2]区间内,很明显这一段是sigmoid(x)函数接近于线性变换的区域,意味着x的小变化会导致非线性函数值较大的变化,也梯度变化较大,对应导数函数图中明显大于0的区域,就是梯度非饱和区。

经过BN后,目前大部分Activation的值落入非线性函数的线性区对输入比较敏感的区域内,其对应的导数远离导数饱和区,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,也就是说收敛地快。

如果都通过BN,那么不就跟把非线性函数替换成线性函数效果相同了?多层的线性函数变换是没有意义的,因为多层线性网络跟一层线性网络是等价的,反复使用矩阵乘以输入,相当于矩阵乘法,多个矩阵相乘还是一个矩阵。线性函数是只拥有一个变量的一阶多项式函数,斜截式直线,非线性函数不是一条直线,是复合函数。激活函数是非线性的可导的,常用的激活函数有sigmoid、tanh、relu、softmax。

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第5张图片

BN让输入分布均值为0方差为1在经过激活函数时,数据表达能力缺失,为了恢复数据表达能力,保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作进行线性变换(y=scale*x+shift),每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并变宽一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。核心思想应该是想找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。

  • 3、训练阶段如何做BatchNorm

要对每个隐层神经元的激活值做BN,可以想象成每个隐层又加上了一层BN操作层,其图示如下:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第6张图片

对于Mini-Batch SGD来说,一次训练过程里面包含m个训练实例,其具体BN操作就是对于隐层内每个神经元的激活值来说,进行如下变换:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_第7张图片

某个神经元对应的原始的激活x通过减去mini-Batch内m个实例获得的m个激活x求得的均值E(x)并除以求得的方差Var(x)来进行转换。

  • 4、BatchNorm的推理过程

BN在训练的时候可以根据Mini-Batch里的若干训练实例进行激活数值调整,但是在推理(inference)的过程中,很明显输入就只有一个实例,看不到Mini-Batch其它实例,那么这时候怎么对输入做BN呢?因为很明显一个实例是没法算实例集合求出的均值和方差的。

既然没有从Mini-Batch数据里可以得到的统计量,那就想其它办法来获得这个统计量,就是均值和方差。可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在推理的时候直接用全局统计量即可。

现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量。

有了均值和方差,每个隐层神经元也已经有对应训练好的Scaling参数和Shift参数,就可以在推导的时候对每个神经元的激活数据计算NB进行变换了。

  • 5、BatchNorm的好处

不仅仅极大提升了训练速度,收敛过程大大加快,还能增加分类效果,一种解释是这是类似于Dropout的一种防止过拟合的正则化表达方式,所以不用Dropout也能达到相当的效果。另外调参过程也简单多了,对于初始化要求没那么高,而且可以使用大的学习率等。总而言之,经过这么简单的变换,带来的好处多得很,这也是为何现在BN这么快流行起来的原因。

你可能感兴趣的:(大数据与人工智能)