Tensorflow2.0学习笔记(七)BatchNorm层

Tensorflow2.0学习笔记(七)BatchNorm层_第1张图片

(1)BN的作用

Tensorflow2.0学习笔记(七)BatchNorm层_第2张图片

从上图可以看出,Sigmoid函数在[-2,2]区间导数值在[0.1,0.25],当输入大于2或者小于2时,导数逼近于0,从而容易出现梯度弥散的现象。通过标准化后,输入值被映射在0附近区域,此处的导数不会太小,不会容易出现梯度弥散的现象。

Tensorflow2.0学习笔记(七)BatchNorm层_第3张图片

如上图所示的损失函数等高线图可知,当x1和x2分布相近时,收敛更加快速,优化轨迹更好。

结论:通过标准化后,输入值被映射在0附近区域,此处的导数不会太小,不会容易出现梯度弥散的现象;网络层输入分布相近,收敛速度更快。

(2)如何保证输入的分布相近?

Tensorflow2.0学习笔记(七)BatchNorm层_第4张图片

其中,m为Batch样本数,Batch内部的均值和方差分别为是计算出来的。

是为了防止出现除0的错误而设置的较小的数,例如le-8。为了提高BN层的表达能力,引入了缩放和平移。

参数反向传播算法自动优化,实现网络层按需要缩放和平移数据的分布的目的。

(3)前向传播

训练过程:

计算当前Batch的,计算BN层的输出见公式(1)

迭代更新全局训练数据的统计值的过程见(2)

Tensorflow2.0学习笔记(七)BatchNorm层_第5张图片

其中,momentum是需要设置的一个超参数,用于平衡更新幅度。

Momentum=0时,直接被更新为最后一个batch的

Momentum=1时,保持不变。

在tensorflow中,Momentum的默认设置为0.99。

测试过程:

其中,均来自训练过程统计或优化,在测试过程中直接使用,并不会更新。

(4)反向更新

在训练过程中,反向传播算法根据损失L求解梯度,按照更新法则自动优化

注意:对于2D的特征输入X:[b,h,w,c],BN层不是计算每一个点的而是在通道C上面统计每个通道上面的所有数据的

Tensorflow2.0学习笔记(七)BatchNorm层_第6张图片

除了C轴上面统计数据的方式,还有如下几种:

Layer Norm:统计每个样本的所有特征的均值和方差

Tensorflow2.0学习笔记(七)BatchNorm层_第7张图片

Instance Norm:统计每个样本的每个通道上特征的均值和方差。

Tensorflow2.0学习笔记(七)BatchNorm层_第8张图片

Group Norm:将通道分成若干组,统计每个样本的通道组内的特征均值和方差。

Tensorflow2.0学习笔记(七)BatchNorm层_第9张图片

(5)BN层

创建BN层:layer=layers.BatchNormalization()

由于BN在训练和测试过程的行为不同,需要通过设置training标志来区分。

 

 

参考资料:Tensorflow 深度学习  龙龙老师

 

你可能感兴趣的:(Tensorflow2.0)