一张图读懂Batch Normalization

        为了加速网络模型的收敛,我们通常在图像的预处理过程中会对图像进行标准化处理,如下图所示,对于Conv1来说输入的特征矩阵来说就满足某一特定的分布,但对于Conv2来说的输入特征矩阵就不一定满足某一分布规律了,因为经过Conv1之后的分布规律就不确定了(这里所说的满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应的feature map的数据要满足的分布规律)。而我们的Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。

一张图读懂Batch Normalization_第1张图片

         “对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。”我们用一个通俗的例子来解释,假设我们输入的图像x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,x=\left ( x^{(1)} ,x^{(2)},x^{(3)}\right ),其中x^{(1)}就代表我们R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们输入图像的R,G,B通道进行处理。原文给我们提供了详细的公式: 

一张图读懂Batch Normalization_第2张图片

        我们刚刚说的要让feature map满足某一分布规律,理论上是指要整个训练样本所对应的feature map的数据满足一定的分布规律,也就是要计算出整个训练集的feature map然后再进行标准化处理,因为我们在训练一个深度学习模型时数据量也就是训练集是很大的,这如果要计算出所有训练集的feature map再进行计算很明显是不可能的。那我们怎么办呢?原论文中给我们提出了一个办法,论文中所说的Batch Normalization,是我们计算一个Batch数据的feature map然后再进行标准化操作,这样就大大降低了数据的处理。当然,在这个过程中,Batch越大越接近整个训练集的分布,效果会越好一些。

        下面通过一张图来进一步理解一下Batch Normalization操作步骤:

一张图读懂Batch Normalization_第3张图片

 上边这张图可能看起来有点懵,给大家解释一下。

        假设一个batch size = 2,feature-1和feature-2是由image1、image2经过一系列卷积池化后得到的特征矩阵,每一个特征矩阵有两个channel,每一个channel的大小都是2×2。那么x^{(1)}代表该batch中所有feature的channel-1数据

x^{(1)}=\left \{ 1,1,1,2,0,-1,2,2 \right \}

同理,x^{(2)}表示该batch中所有feature的channel-2的数据

x^{(2)}=\left \{ -1,1,0,1,0,-1,3,1 \right \}

 得到x^{(1)}x^{(2)}之后,我们要分别计算他们两个的均值和方差

\mu _{1}=\frac{1}{m}\sum_{i=1}^{m}x_{i}^{(1)}=\frac{1+1+1+2+0+(-1)+2+2}{8}=1

\mu _{2}=\frac{1}{m}\sum_{i=1}^{m}x_{i}^{(2)}=\frac{(-1)+1+0+1+0+(-1)+3+1}{8}=0.5

\sigma_{1}^{2}=\frac{1}{m}\sum_{i=1}^{m}(x_{i}^{(1)}-\mu_{1})^{^{2}}=1

\sigma_{2}^{2}=\frac{1}{m}\sum_{i=1}^{m}(x_{i}^{(2)}-\mu_{2})^{^{2}}=1.5

这样我们得到两个channel的均值与方差,那么我们接下来就要根据下面的公式去计算BN层处理过后的特征矩阵的值(公式中的\varepsilon是一个很小的常量,这是为了防止出现分母等于0的情况)

\widehat{x}_{i}=\frac{x_{i}-\mu}{\sqrt{\sigma^{2}+\varepsilon}}

        但是在我们预测的过程中通常都是输入一张图片进行预测,此时的batch为1,如果再用上边的方法去计算均值和方差就没有意义了,上述方法是建立在batch size大于等于2的情况下。所以我们在训练过程中要不断地去计算每一个batch的均值和方差,并使用移动平均(moving average)的方法去记录统计的均值和方差,在迭代完成后我们近似的认为所统计的均值和方差就等于整个训练集的均值和方差。

最后给出李宏毅老师关于batch normalization的视频讲解:李宏毅深度学习(2017)_哔哩哔哩_bilibili

你可能感兴趣的:(开发语言,python,conda)