Normalization 系列方法(一):CV【4】:Batch normalization
Normalization 系列方法(二):CV【5】:Layer normalization
对于早前的 CNN 模型来说,大多使用 batch normalization
进行归一化,随着 Transformer
在计算机视觉领域掀起的热潮, layer normalization
开始被用于提升传统的 CNN 的性能,在许多工作中展现了不错的提升
本文主要是对 batch normalization
用法的总结
神经网络可以看成是上图形式,对于中间的某一层,其前面的层可以看成是对输入的处理,后面的层可以看成是损失函数。一次反向传播过程会同时更新所有层的权重 W 1 , W 2 , ⋯ , W L W_1, W_2, \cdots, W_L W1,W2,⋯,WL,前面层权重的更新会改变当前层输入的分布,而跟据反向传播的计算方式,我们知道,对 W k W_k Wk 的更新是在假定其输入不变的情况下进行的
如果假定第 k k k 层的输入节点只有 2 个,对第 k k k 层的某个输出节点而言,相当于一个线性模型 y = w 1 x 1 + w 2 x 2 + b y = w_1x_1 + w_2x_2 + b y=w1x1+w2x2+b,如下图所示,
假定当前输入 x 1 x_1 x1 和 x 2 x_2 x2 的分布如图中圆点所示,本次更新的方向是将直线 H 1 H_1 H1 更新成 H 2 H_2 H2,但是当前面层的权重更新完毕,当前层输入的分布换成了另外一番样子,直线相对输入分布的位置可能变成了 H 3 H_3 H3,下一次更新又要根据新的分布重新调整
深度学习这种包含很多隐层的网络结构,在训练过程中,因为各层参数不停在变化,每一层的参数更新都会导致上层的输入数据在输出时分布规律发生了变化,并且这个差异会随着网络深度增大而增大 —— 这就是 Internal Covariate Shift
每个神经元的输入数据不再是独立同分布的了,会导致如下问题:
所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布
白化过程就是对数据进行如下的操作:
之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话,那么神经网络会较快收敛。
BN
作者就开始尝试:图像是深度神经网络的输入层,做白化能加快收敛,那么其实对于深度网络来说,其中某个隐层的神经元是下一层的输入,意思是其实深度神经网络的每一个隐层都是输入层,不过是相对下一层来说而已,那么能不能对每个隐层都做白化呢?
这就是启发BN产生的原初想法,而 BN
也确实就是这么做的,可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作
Batch Normalization
,简称 BatchNorm
或 BN
,翻译为批归一化,是神经网络中一种特殊的层,如今已是各种流行网络的标配。在原论文中,BN
被建议插入在(每个)ReLU
激活层前面,如下所示:
Sigmoid
函数来说,意味着激活输入值 W U + B WU + B WU+B 是大的负值或正值)BN
就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为 0 0 0 方差为 1 1 1 的标准正态分布
如果 batch size 为 m m m,则在前向传播过程中,网络中每个节点都有 m m m 个输出,所谓的 Batch Normalization
,就是对该层每个节点的这 m m m 个输出进行归一化再输出,具体计算方式如下:
其操作可以分成2步,
Standardization
,得到 zero mean unit variance
的分布 x ^ \hat{x} x^scale and shift
,缩放并平移到新的分布 y y y,具有新的均值 β \beta β 方差 γ \gamma γ假设BN层有 d d d 个输入节点,则 x x x 可构成 d × m d \times m d×m大小的矩阵 X X X,BN
层相当于通过行操作将其映射为另一个 d × m d \times m d×m 大小的矩阵 Y Y Y,如下所示:
其中, x i ( b ) x^{(b)}_i xi(b)表示输入当前 batch
的 b − t h b-th b−th 样本时该层 i − t h i-th i−th 输入节点的值, x i x_i xi 为 [ x i ( 1 ) , x i ( 2 ) , ⋯ , x i ( m ) ] [x_i^{(1)}, x_i^{(2)}, \cdots, x_i^{(m)}] [xi(1),xi(2),⋯,xi(m)] 构成的行向量,长度为 batch size
m m m, μ \mu μ 和 σ \sigma σ 为该行的均值和标准差, ϵ \epsilon ϵ 为防止除零引入的极小量(可忽略), γ \gamma γ 和 β \beta β 为该行的 scale
和 shift
参数,可知
scale
和 shift
参数,用于控制 y i y_i yi 的方差和均值
scale
和 shift
),这两个参数是通过训练来学习到的,用来对变换后的激活反变换,使得网络表达能力增强BN
层中, x i x_i xi 和 x j x_j xj 之间不存在信息交流 ( i ≠ j ) (i \neq j) (i=j)可见,无论 x i x_i xi 原本的均值和方差是多少,通过 BatchNorm
后其均值和方差分别变为待学习的 β \beta β 和 γ \gamma γ
下图给出了一个计算均值 μ B \mu_{\mathcal{B}} μB 和方差 σ B 2 \sigma_{\mathcal{B}}^{2} σB2 的示例:
上图展示了一个 batch size
为 2(两张图片)的 Batch Normalization
的计算过程
moving average
) 的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。对于目前的神经网络计算框架,一个层要想加入到网络中,要保证其是可微的,即可以求梯度
反向传播求梯度只需抓住一个关键点,如果一个变量对另一个变量有影响,那么他们之间就存在偏导数,找到直接相关的变量,再配合链式法则,公式就很容易写出了,如下所示:
根据反向传播的顺序,首先求取损失 ℓ \ell ℓ 对 BN
层输出 y i y_i yi 的偏导 ∂ ℓ ∂ y i \frac{\partial \ell}{\partial y_i} ∂yi∂ℓ,然后是对可学习参数的偏导 ∂ ℓ ∂ γ \frac{\partial \ell}{\partial \gamma} ∂γ∂ℓ 和 ∂ ℓ ∂ β \frac{\partial \ell}{\partial \beta} ∂β∂ℓ,用于对参数进行更新,想继续回传的话还需要求对输入 x x x 偏导,于是引出对变量 μ \mu μ、 σ 2 \sigma^2 σ2 和 x ^ \hat{x} x^ 的偏导,根据链式法则再求这些变量对 x x x 的偏导
前向传播与反向传播的过程如下图所示:
反向传播的每块详细内容(从右往左)如下所示:
在实际实现时,通常以矩阵或向量运算方式进行,比如逐元素相乘、沿某个axis求和、矩阵乘法等操作
在预测阶段,所有参数的取值是固定的,对 BN
层而言,意味着 μ \mu μ、 σ \sigma σ、 γ \gamma γ 和 β \beta β 都是固定值
mini batch
的统计量,随着输入 batch
的不同, μ \mu μ 和 σ \sigma σ 一直在变化。在预测阶段,输入数据可能只有 1 条,该使用哪个 μ \mu μ 和 σ \sigma σ,或者说,每个 BN
层的 μ \mu μ 和 σ \sigma σ 该如何取值?可以采用训练收敛最后几批 mini batch
的 μ \mu μ 和 σ \sigma σ 的期望,作为预测阶段的 μ \mu μ 和 σ \sigma σ,如下所示:
因为 standardization
和 scale and shift
均为线性变换,在预测阶段所有参数均固定的情况下,参数可以合并成 y = k x + b y = kx + b y=kx+b 的形式,如上图中行号11所示
使用 batch normalization
的优点:
bias
置为 0,因为 batch normalization
的 standardization
过程会移除直流分量,所以不再需要 bias
batch normalization
后,对与同一个输出节点相连的权重进行放缩,其标准差 σ \sigma σ 也会放缩同样的倍数,相除抵消sigmoid
和 tanh
了,理由同上,BN
抑制了梯度消失。batch normalization
具有某种正则作用,不需要太依赖 dropout
,减少过拟合使用 batch normalization
的缺点:
BN
对于 batch_size
的大小还是比较敏感的,batch_size
很小的时候,其梯度不够稳定,效果反而不好BatchNorm
?feature map
, 1 1 1 个 feature map
有 1 1 1 对 γ \gamma γ 和 β \beta β 参数,同一 batch
同 channel
的 feature map
共享同一对 γ \gamma γ 和 β \beta β 参数,若卷积层有 n n n 个卷积核,则有 n n n 对 γ \gamma γ 和 β \beta β 参数scale and shift
过程可不可以?BatchNorm
有两个过程,standardization
和 scale and shift
,前者是机器学习常用的数据预处理技术,在浅层模型中,只需对数据进行 standardization
即可Batch Normalization
可以只有 standardization
,但网络的表达能力会下降
zero mean unit variance
并不见得是最好的选择scale and shift
,有利于分布与权重的相互协调,特别地,令 γ = 1 , β = 0 \gamma = 1, \beta = 0 γ=1,β=0 等价于只用 standardization
,令 γ = σ , β = μ \gamma = \sigma, \beta = \mu γ=σ,β=μ等价于没有 BN
层scale and shift
涵盖了这 2 种特殊情况,在训练过程中决定什么样的分布是适合的,所以使用 scale and shift
增强了网络的表达能力。BN
层放在激活函数前面还是后面?BN
的论文建议将 BN
层放置在 ReLU
前,因为 ReLU
激活函数的输出非负,不能近似为高斯分布ReLU
后还好一些
BN
究竟应该放在激活的前面还是后面?以及,BN
与其他变量,如激活函数、初始化方法、dropout
等,如何组合才是最优?可能需要根据具体的实验情况,具体分析BN
层为什么有效?BN
层让损失函数更平滑
BN
更有利于梯度下降
μ s t a t i s t i c \mu _{statistic} μstatistic 和 σ s t a t i s t i c 2 \sigma _{statistic}^{2} σstatistic2 的具体更新策略如下,其中 momentum
默认取 0.1 0.1 0.1:
这里要注意一下,在 pytorch 中对当前批次 feature 进行 BN
处理时所使用的 σ n o w 2 \sigma _{now}^{2} σnow2 是总体标准差,计算公式如下:
在更新统计量 σ s t a t i s t i c 2 \sigma _{statistic}^{2} σstatistic2 时采用的 σ n o w 2 \sigma _{now}^{2} σnow2 是样本标准差,计算公式如下:
下面是使用 pytorch 做的测试,代码如下(参考博主太阳花的小绿豆):
bn_process
函数是自定义的 BN
处理方法验证是否和使用官方bn处理方法结果一致
bn_process
中计算输入 batch 数据的每个维度(这里的维度是 channel 维度)的均值和标准差(标准差等于方差开平方)import numpy as np
import torch.nn as nn
import torch
def bn_process(feature, mean, var):
feature_shape = feature.shape
for i in range(feature_shape[1]):
# [batch, channel, height, width]
feature_t = feature[:, i, :, :]
mean_t = feature_t.mean()
# 总体标准差
std_t1 = feature_t.std()
# 样本标准差
std_t2 = feature_t.std(ddof=1)
# bn process
# 这里记得加上eps和pytorch保持一致
feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)
# update calculating mean and var
mean[i] = mean[i] * 0.9 + mean_t * 0.1
var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
print(feature)
# 随机生成一个batch为2,channel为2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())
# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)
bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)
参考资料1
参考资料2
参考资料3
参考资料4