机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。
那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布
深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢,很多论文都是解决这个问题的,比如ReLU激活函数,再比如Residual Network(残差网络),BN本质上也是解释并从某个不同的角度来解决这个问题的。
BN是用来解决“Internal Covariate Shift”问题的,Mini-Batch SGD相对于One Example SGD的两个优势:梯度更新方向更准确;并行计算速度快,BatchNorm是基于Mini-Batch SGD的。
covariate shift的概念:如果ML系统实例集合
然后提出了BatchNorm的基本思想:能不能让每个隐层节点的激活输入分布固定下来
深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
其实一句话就是:对于每个隐层神经元,把在非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,也就是说收敛地快。
经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程;如果都通过BN,那么不就跟把非线性函数替换成线性函数效果相同了?这意味着什么?如果是多层的线性函数变换其实这个深层是没有意义的,因为多层线性网络跟一层线性网络是等价的。这意味着网络的表达能力下降了,这也意味着深度的意义就没有了。所以BN为了保证非线性的获得,对变换后满足均值为0方差为1的x又进行了scale加上shift操作( y = s c a l e ∗ x + s h i f t y=scale*x+shift y=scale∗x+shift) 每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动,核心思想应该是想找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢
假设对于一个深层神经网络来说,其中两层结构如下:
要对每个隐层神经元的激活值做BN,可以想象成每个隐层又加上了一层BN操作层,它位于 X = W U + B X=WU+B X=WU+B激活值获得之后,非线性函数变换之前,其图示如下:
对于Mini-Batch SGD来说,一次训练过程里面包含m个训练实例,其具体BN操作就是对于隐层内每个神经元的激活值来说,进行如下变换:
x ^ ( k ) = x ( k ) − E [ x ( k ) ] Var [ x ( k ) ] \hat{x}^{(k)}=\frac{x^{(k)}-E\left[x^{(k)}\right]}{\sqrt{\operatorname{Var}\left[x^{(k)}\right]}} x^(k)=Var[x(k)]x(k)−E[x(k)]
要注意,这里t层某个神经元的 x ( k ) x^{(k)} x(k)不是指原始输入,即不是t-1层每个神经元的输出,而是t层这个神经元的线性激活 X = W U + B X=WU+B X=WU+B,这里的U才是t-1层神经元的输出。变换的意思是:某个神经元对应的原始的激活x通过减去mini-Batch内m个实例获得的m个激活x求得的均值E(x)并除以求得的方差Var(x)来进行转换
变换后某个神经元的激活x形成了均值为0,方差为1的正态分布,目的是把值往后续要进行的非线性变换的线性区拉动,增大导数值,增强反向传播信息流动性,加快训练收敛速度。但是这样会导致网络表达能力下降,为了防止这一点,每个神经元增加两个调节参数(scale和shift),这两个参数是通过训练来学习到的,用来对变换后的激活反变换,使得网络表达能力增强,即对变换后的激活进行如下的scale和shift操作,这其实是变换的反操作:
y ( k ) = γ ( k ) x ^ ( k ) + β ( k ) y^{(k)}=\gamma^{(k)} \hat{x}^{(k)}+\beta^{(k)} y(k)=γ(k)x^(k)+β(k)
BN在训练的时候可以根据Mini-Batch里的若干训练实例进行激活数值调整,但是在推理(inference)的过程中,很明显输入就只有一个实例,看不到Mini-Batch其它实例,那么这时候怎么对输入做BN呢?因为很明显一个实例是没法算实例集合求出的均值和方差的。这可如何是好?
可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在推理的时候直接用全局统计量即可
决定了获得统计量的数据范围,那么接下来的问题是如何获得均值和方差的问题。很简单,因为每次做Mini-Batch训练时,都会有那个Mini-Batch里m个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量,即:
E [ x ] ← E B [ μ B ] Var [ x ] ← m m − 1 E B [ σ B 2 ] \begin{array}{l} E[x] \leftarrow E_{\mathrm{B}}\left[\mu_{\mathrm{B}}\right] \\ \operatorname{Var}[x] \leftarrow \frac{m}{m-1} E_{\mathrm{B}}\left[\sigma_{\mathrm{B}}^{2}\right] \end{array} E[x]←EB[μB]Var[x]←m−1mEB[σB2]
有了均值和方差,每个隐层神经元也已经有对应训练好的Scaling参数和Shift参数,就可以在推导的时候对每个神经元的激活数据计算BN进行变换了,在推理过程中进行BN采取如下方式:
y = γ Var [ x ] + ε ⋅ x + ( β − γ ⋅ E [ x ] Var [ x ] + ε ) y=\frac{\gamma}{\sqrt{\operatorname{Var}[x]+\varepsilon}} \cdot x+\left(\beta-\frac{\gamma \cdot E[x]}{\sqrt{\operatorname{Var}[x]+\varepsilon}}\right) y=Var[x]+εγ⋅x+(β−Var[x]+εγ⋅E[x])
这个公式其实和训练时的公式 y ( k ) = γ ( k ) x ^ ( k ) + β ( k ) y^{(k)}=\gamma^{(k)} \hat{x}^{(k)}+\beta^{(k)} y(k)=γ(k)x^(k)+β(k)是等价的;通过简单的合并计算推导就可以得出这个结论。那么为啥要写成这个变换形式呢?在实际运行的时候,按照这种变体形式可以减少计算量,为啥呢?因为对于每个隐层节点来说 γ Var [ x ] + ε γ ⋅ E [ x ] Var [ x ] + ε \frac{\gamma}{\sqrt{\operatorname{Var}[x]+\varepsilon}} \frac{\gamma \cdot E[x]}{\sqrt{\operatorname{Var}[x]+\varepsilon}} Var[x]+εγVar[x]+εγ⋅E[x]都是固定值,这样这两个值可以事先算好存起来,在推理的时候直接用就行了,这样比原始的公式每一步骤都现算少了除法的运算过程,乍一看也没少多少计算量,但是如果隐层节点个数多的话节省的计算量就比较多了
在深度神经网络训练的过程中,通常以输入网络的每一个mini-batch进行训练,这样每个batch具有不同的分布,使模型训练起来特别困难
Internal Covariate Shift (ICS)问题:在训练的过程中,激活函数会改变各层数据的分布,随着网络的加深,这种改变(差异)会越来越大,使模型训练起来特别困难,收敛速度很慢,会出现梯度消失的问题
全连接层或卷积操作之后,激活函数之前
加入缩放和平移变量的原因是:保证每一次数据经过归一化后还保留原有学习来的特征,同时又能完成归一化操作,加速训练
常用的 Normalization 方法:BN、LN、IN、GN