此系列是记录DL、ML相关的一些基础概念,便于理解和回顾!一天两道!哈哈哈~~,水滴石穿!小白水平,若有错误,欢迎指正!
B N BN BN 算法出自论文:
( B a t c h N o r m a i l z a t i o n : A c c e l e r a t i n g D e e p N e t w o r k T r a i n i n g b y R e d u c i n g I n t e r n a l C o v a r i a t e S h i f t ) (Batch \ Normailzation:Accelerating \ Deep \ Network \ Training \ by \ Reducing \ Internal \ Covariate \ Shift) (Batch Normailzation:Accelerating Deep Network Training by Reducing Internal Covariate Shift)
参考链接:Batch Normalization 学习笔记
网络训练起来,参数更新,第一个输入层数据是我们输入的,分布不会变,其余后面所有层的输入数据分布都会发生变化。网络中间层在训练过程中,数据分布的改变被论文作者称为: I n t e r n a l C o v a r i a t e S h i f t Internal \ Covariate \ Shift Internal Covariate Shift。
提出该 BN 算法的目的:
在神经网络训练的时候,需要对数据做归一化处理,为什么需要归一化和归一化的好处?
BN算法就是解决在训练过程中,中间层数据分布发生变化的情况。
BN算法的优点:
BN的本质:
非线性变换前的激活输入值(y=Wx+b)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,所以导致训练收敛慢,一般体现为:整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于 S i g m o i d Sigmoid Sigmoid 函数来说,意味着激活输入值 y=wx+b 是 -inf 或者 +inf),这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
BN的实现:
最有用的预处理为白化,因此有想法者提出在每层之前增加 PCA 白化,也就是 先对数据进行去相关然后再进行归一化,这样基本满足了数据的0均值、单位方差、弱相关性。但是这样是不可取的,因为在白化过程中会计算协方差矩阵、求逆等操作,计算量会很大,另外,在反向传播时,白化的操作不一定可微。于是为了简化计算,作者忽略了第1个要求,仅仅使用了下面的公式进行预处理,也就是近似白化预处理:
减去均值,除以方差,将其归一化到:均值为0,方差为1: x ^ ( k ) = x ( k ) − E [ x ( k ) ] Var [ x ( k ) ] \widehat{x}^{(k)}=\frac{x^{(k)}-\mathrm{E}\left[x^{(k)}\right]}{\sqrt{\operatorname{Var}\left[x^{(k)}\right]}} x (k)=Var[x(k)]x(k)−E[x(k)]
经过上述近白化处理后某个神经元的激活x形成了均值为0,方差为1的正态分布,目的是把值往后续要进行的非线性变换的线性区拉动,增大导数值,增强反向传播信息流动性,加快训练收敛速度。但是这样会导致网络表达能力下降。打个比方,比如我网络中间某一层学习到特征数据本身就分布在 S i g m o i d Sigmoid Sigmoid 激活函数的两侧,你强制把它给我归一化处理、标准差也限制在了1,把数据变换成分布于 S i g m o i d Sigmoid Sigmoid 函数的中间部分,这样就相当于我这一层网络所学习到的特征分布被破坏了,这可怎么办?而且这个归一化层,在网络构建中,它是一个可学习、有参数的网络层。所以:每个神经元增加两个调节参数(scale和shift),这两个参数是通过训练来学习到的,用来对变换后的激活反变换,使得网络表达能力增强,即对变换后的激活进行如下的scale和shift操作,这其实是变换的反操作: x ^ ( k ) = x ( k ) − E [ x ( k ) ] Var [ x ( k ) ] \widehat{x}^{(k)}=\frac{x^{(k)}-\mathrm{E}\left[x^{(k)}\right]}{\sqrt{\operatorname{Var}\left[x^{(k)}\right]}} x (k)=Var[x(k)]x(k)−E[x(k)] y ( k ) = γ ( k ) x ^ ( k ) + β ( k ) y^{(k)}=\gamma^{(k)} \widehat{x}^{(k)}+\beta^{(k)} y(k)=γ(k)x (k)+β(k) β ( k ) = E [ x ( k ) ] γ ( k ) = Var [ x ( k ) ] \beta^{(k)}=\mathrm{E}\left[x^{(k)}\right] \quad \gamma^{(k)}=\sqrt{\operatorname{Var}\left[x^{(k)}\right]} β(k)=E[x(k)]γ(k)=Var[x(k)]是可以恢复出原始的某一层所学到的特征的。引入可学习重构参数γ、β,让网络可以学习恢复出原始网络所要学习的特征分布。
终,BN算法可总结为:
μ B ← 1 m ∑ i = 1 m x i If mini-batch mean σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 //mini-batch variance x ^ i ← x i − μ B σ B 2 + ϵ I/ normalize y i ← γ x ^ i + β ≡ B N γ , β ( x i ) If scale and shift \begin{array}{rlrl}{\mu_{\mathcal{B}}} & {\leftarrow \frac{1}{m} \sum_{i=1}^{m} x_{i}} & {} & {\text { If mini-batch mean }} \\ {\sigma_{\mathcal{B}}^{2}} & {\leftarrow \frac{1}{m} \sum_{i=1}^{m}\left(x_{i}-\mu_{\mathcal{B}}\right)^{2}} & {\text { //mini-batch variance }} \\ {\widehat{x}_{i}} & {\leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}}} & {} & {\text { I/ normalize }} \\ {y_{i}} & {\leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{B} \mathrm{N}_{\gamma, \beta}\left(x_{i}\right)} & {} & {\text { If scale and shift }}\end{array} μBσB2x iyi←m1∑i=1mxi←m1∑i=1m(xi−μB)2←σB2+ϵxi−μB←γx i+β≡BNγ,β(xi) //mini-batch variance If mini-batch mean I/ normalize If scale and shift
BN在训练和测试时的差别 :
对于BN,在训练时,是对每一批的训练数据进行归一化。使用BN的目的就是每个批次分布稳定。当一个模型训练完成之后,它的所有参数都确定了,包括均值和方差,gamma和bata。
而在测试时,比如进行一个样本的预测,就并没有batch的概念,因此,这个时候用的均值和方差是全量训练数据的均值和方差,也就是使用全局统计量来代替批次统计量,这个可以通过移动平均法求得。具体做法是,训练时每个批次都会得到一组(均值、方差),然后对这些数据求数学期望!每轮batch后都会计算,也称为移动平均。
BN代码实现:
def BN(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 计算均值
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True) # 计算方差
if is_training:
x_hat = (x - x_mean) / torch.sqrt(x_var + eps) # 类白化操作
moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean # 移动平均
moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
else:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean) # 反变换