透彻理解BN(Batch Normalization)层

什么是BN

Batch Normalization是2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》中提出的数据归一化方法,往往用在深度神经网络中激活层之前。其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失。并且起到一定的正则化作用,几乎代替了Dropout

批量归一化:通过减少内部协变量偏移来加速深度网络训练

由于每层输入的分布在训练过程中随着前一层的参数发生变化而发生变化,因此训练深度神经网络很复杂。由于需要较低的学习率和仔细的参数初始化,这会减慢训练速度,并且使得训练具有饱和非线性的模型变得非常困难。我们将这种现象称为内部协变量偏移,并通过归一化层输入来解决该问题。我们的方法的优势在于将标准化作为模型架构的一部分,并为每个训练小批量执行标准化。 Batch Normalization 允许我们使用更高的学习率,并且在初始化时不那么小心。它还充当正则化器,在某些情况下消除了 Dropout 的需要。应用于最先进的图像分类模型,批量归一化在训练步骤减少 14 倍的情况下实现了相同的精度,并且以显着的优势击败了原始模型。使用一组批量归一化网络,我们改进了 ImageNet 分类的最佳发布结果:达到 4.9% 的前 5 名验证错误(和 4.8% 的测试错误),超过了人工评估员的准确性

机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。

BN解决“Internal Covariate Shift”问题

在训练的过程中,即使对输入层做了归一化处理使其变成标准正态,随着网络的加深,函数变换越来越复杂,许多隐含层的分布还是会彻底放飞自我,变成各种奇奇怪怪的正态分布,并且整体分布逐渐往非线性函数(也就是激活函数)的取值区间的上下限两端靠近。对于sigmoid函数来说,就意味着输入值是大的负数或正数,这导致反向传播时底层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

为了解决上述问题,又想到网络的某个隐含层相对于之后的网络就相当于输入层,所以BN的基本思想就是:把网络的每个隐含层的分布都归一化到标准正态。其实就是把越来越偏的分布强制拉回到比较标准的分布,这样使得激活函数的输入值落在该激活函数对输入比较敏感的区域,这样一来输入的微小变化就会导致损失函数较大的变化。通过这样的方式可以使梯度变大,就避免了梯度消失的问题,而且梯度变大意味着收敛速度快,能大大加快训练速度。

简单说来就是:传统的神经网络只要求第一个输入层归一化,而带BN的神经网络则是把每个输入层(把隐含层也理解成输入层)都归一化。

BN的核心公式理解

透彻理解BN(Batch Normalization)层_第1张图片
透彻理解BN(Batch Normalization)层_第2张图片

pytorch BatchNorm2d

BATCHNORM2D
透彻理解BN(Batch Normalization)层_第3张图片

参数介绍

  • num_features,输入数据的通道数,归一化时需要的均值和方差是在每个通道中计算的
  • eps,用来防止归一化时除以0
  • momentum,滑动平均的参数,用来计算running_mean和running_var
  • affine,是否进行仿射变换,即缩放操作
  • track_running_stats,是否记录训练阶段的均值和方差,即running_mean和running_var

BN层的状态包含五个参数

  • weight,缩放操作的γ
  • bias,缩放操作的β
  • running_mean,训练阶段统计的均值,测试阶段会用到。
  • running_var,训练阶段统计的方差,测试阶段会用到
  • num_batches_tracked,训练阶段的batch的数目,如果没有指定momentum,则用它来计算running_mean和running_var。一般momentum默认值为0.1,所以这个属性暂时没用。

weight和bias这两个参数需要训练,而running_mean、running_val和num_batches_tracked不需要训练,它们只是训练阶段的统计值

训练与推理时BN中的均值、方差分别是什么

训练时,均值、方差分别是该批次内数据相应维度的均值与方差;
推理时,均值、方差是基于所有批次的期望计算所得,公式如
在这里插入图片描述

BN两大效果

  • 收敛速率增加
  • 可以达到更好的精度

参考文档

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
【基础算法】六问透彻理解BN(Batch Normalization)
Pytorch-BN层详细解读
【深度学习】深入理解Batch Normalization批标准化

你可能感兴趣的:(pytorch,batch,神经网络,深度学习)