cs231n -- Batch Normalization

学习Batch Normalization时,对BN的工作原理没有弄清楚,查阅了不少资料后才对它有了较为深入的理解,这里分享一下我自己对于BN的理解,希望能给同样有困惑的同学帮助,大家多多交流学习。

附上原论文地址,建议有时间的话看一遍,会对过程的计算有更好的了解。传送门

1.What is Batch Normalization(BN)

Batch Normalization(批归一化),像FC-layer、activation function一样属于网络的一个层,常用在非线性层之前。BN有两个操作:

    “normalize”,通过对mini-batch进行zscore(减均值,除标准差)化为unit Gaussian(均值为0,方差为1)。

    “scale and shift”,对归一化后的标准高斯分布进行放缩和平移操作(y=scale*x+shift),scale和shift两个参数通过学习得到。

嗯,说到底也就这么两步,不过竟然就可以有强大的魔力使我们可以选择更高的学习率、拥有更快的收敛速度,而且由于自身可视为正则化的一种形式,减轻了我们对dropout、L2的需求。

2.Motivation and why it works

从论文名字知道,BN是用来解决“Internal Covariate Shift”(ICS)问题的,那么什么是ICS?

    我们知道在训练网络时对数据进行预处理,如whitening、PCA、zscore甚至仅仅去均值都能加速收敛,因为模型的输入特赠不相关且满足标准正态分布时,模型的表现一般较好。而我们对数据进行预处理是使网络第一层有较好的输入特征,但随着模型层数的加深,网络的非线性变换使得每一层的结果不再满足N(0, 1)分布,隐藏层的层间变化也越来越大,但它们的label都是一样的,如此网络还需要学习如何去迎合分布的变化,增大训练的难度,这便是“Covariate Shift”又因为对层间信号分析,也即是“internal”的由来。

由此提出BN的基本思想:让每个隐藏层的输入特征分布固定下来。数据是网络的输入,我们对数据进行zscore可以加速收敛;在网络中,某个隐藏层可以看作是下一层的输入,那么对隐藏层进行zscore... 再往里看,其实BN真正解决的问题是“梯度弥散”或“梯度爆炸”,举个栗子,\[{\text{0}}{\text{.}}{{\text{9}}^{{\text{30}}}} \approx {\text{0}}{\text{.04}}\],在训练过程中,整体分布会逐渐往非线性函数的取值区间的上下限两端靠近,导致深层梯度消失,这是训练深层网络收敛越来越慢的原因。而BN通过正则化,将分布重新变成标准正态分布,使得其落在非线性函数对输入比较敏感的区域,将原本会减小的scale放大,如此梯度增大,收敛加快;同时也可视为一种regulazation的方式。

那么问题又来了,如此岂不是相当于将非线性函数替换为了线性函数,那网络最后的结果跟一层线性网络没有区别,这样网络的表达能力大为下降。因此BN为了网络的非线性表达能力,对变换后的unit Gaussian进行“scale and shift”,通过这两个参数将分布向非线性区移动,核心应该是想找到一个较好的平衡点,既能享受非线性较强表达能力的好处,又避免太靠近非线性区两头使得网络收敛速度太慢。由于是两个可学习的参数,使得新生成的分布能够更灵活,可能会使模型训练的效率更高。

3.How BN

我们已经知道了为什么要进行BN以及为什么BN可以较好地工作,接下来看看如何具体实现BN。

forward propagation

cs231n -- Batch Normalization_第1张图片

首先计算出mini-batch的均值与方差,接着对mini-batch中的每个特征做标准化处理,最后利用缩放参数与偏移参数对特征进行后处理得到输出。


当模型进行训练的时候,记录每一个mini-batch的均值方差,在预测时利用均值与方差的无偏估计量来进行BN操作,也就是


输出经过等效替换也就表示为

backward propagation

还是公式镇楼

cs231n -- Batch Normalization_第2张图片

链式法则一个个算就行

'\begin{align}

'\begin{align}

'\begin{align}

若仔细观察式(3)与式(4),我们令


则可以将式(3)与式(4)简化为


这样做一个简单的替换,在实现代码的时候,运算会简化很多。

4.When BN?

训练网络时遇到收敛速度很慢或者梯度爆炸等难以训练的情况时可以尝试BN。另外,一般情况下使用BN来加快训练速度、防止过拟合来提高模型精度。

最后附上论文一张图看看效果,的确是功能强大啊嘎嘎。

cs231n -- Batch Normalization_第3张图片



你可能感兴趣的:(cs231n)