BatchNorm的个人解读和Pytorch中BN的源码解析

BatchNorm已经作为常用的手段应用在深度学习中,效果显著,加快了训练速度,保证了梯度的流动,防止过拟合,降低网络对初始化权重敏感程度,减少对调参的要求。今天自己就做个总结,记录一下BatchNorm,并从Pytorch源码来看BatchNorm。

BN的灵感来源

讲解BN之前,我们需要了解BN是怎么被提出的。在机器学习领域,数据分布是很重要的概念。如果训练集和测试集的分布很不相同,那么在训练集上训练好的模型,在测试集上应该不奏效(比如用ImageNet训练的分类网络去在灰度医学图像上finetue在测试,效果应该不好)。对于神经网络来说,如果每一层的数据分布都不一样,后一层的网络则需要去学习适应前一层的数据分布,这相当于去做了domian的adaptation,无疑增加了训练难度,尤其是网络越来越深的情况。

实际上,确实如此,不同层的输出的分布是有差异的。BN的那篇论文中指出,不同层的数据分布会往激活函数的上限或者下限偏移。论文称这种偏移为internal Covariate Shift,internal指的是网络内部。
BN就是为了解决偏移的,解决的方式也很简单,就是让每一层的分布都normalize到标准高斯分布。(这里的每一层并不准确,BN是根据划分数据的集合去做Normalization,不同的划分方式也就出现了不同的Normalization,如GN,LN,IN)

BN是如何做的

BN的行为根据训练和测试不同行为而不同。

在训练中使用BN

BN中的B是batchsize,就是说BN基于mini-batch SGD,首先训练数据必须是一个批次,含有多个样本。
假设特征 F ∈ R B C H W F \in R^{BCHW} FRBCHW,在通道维度上求均值和方差
μ i = m e a n ( f b , i , h , w ) , i ∈ r a n g e ( 1 , C ) , \mu_i = mean(f_{b,i,h,w}), i \in range(1,C), μi=mean(fb,i,h,w),irange(1,C),
σ i 2 = v a r ( f b , i , h , w ) \sigma_i^2 = var(f_{b,i,h,w}) σi2=var(fb,i,h,w)

举个栗子。
BatchNorm的个人解读和Pytorch中BN的源码解析_第1张图片
BatchNorm的个人解读和Pytorch中BN的源码解析_第2张图片
以此类推
BatchNorm的个人解读和Pytorch中BN的源码解析_第3张图片
计算得到一个四维的向量,作为这个层的 μ i , σ i 2 \mu_i ,\sigma_i^2 μi,σi2
然后,
x i ^ = ( x i − μ i ) / σ i \hat{x_i} = (x_i - \mu_i) / \sigma_i xi^=(xiμi)/σi
y i ^ = ( x i ^ × r i ) + β i \hat{y_i} = (\hat{x_i} \times r_i) + \beta_i yi^=(xi^×ri)+βi
这里第二个公式绝对是有用的, r , β r, \beta r,β是要学习的参数,参与训练。为什么需要这个公式呢。
因为我们在第一个公式中减去了均值和除以方差,降低了非线性能力。第二个公式就是去补偿非线性能力的。甚至通过学习均值和方差,BN是可以还原回原来的特征。
我们在一些源码中,可以看到带有BN的卷积层,bias设置为False,就是因为即便卷积之后加上了Bias,在BN中也是要减去的,所以加Bias带来的非线性就被BN一定程度上抵消了。需要补偿。
然后再接激活函数即可。这就是完成了BN训练过程

在测试中使用BN

在训练中使用BN是要计算均值和方差的,而这两个统计量是随着样本不同而变化的。如果在测试中依然遵循这样的方式,那么无疑同一个样本在不同的batch中预测会得到不一样的概率值,这显然是不对的。
在测试中,BN根据训练过程中计算的均值和方差,使用滑动平均去记录这些值。在测试的时候统一使用记录下来的滑动平均值,这一点可以从源码中看出来。所以在TensorFlow或者Pytorch中,BN的代码分别有is_training 和 self.training字段,就是为了区别使用行为的。

举个例子。
在训练过程的第t次迭代中,我们得到了均值u和方差sigma。那么u和sigma将使用如下方式记录下来。
μ m e a n t = μ m e a n t − 1 ∗ 0.9 + 0.1 ∗ μ \mu_{mean}^t = \mu_{mean}^{t-1} * 0.9 + 0.1* \mu μmeant=μmeant10.9+0.1μ

σ m e a n t = σ m e a n t − 1 ∗ 0.9 + 0.1 ∗ σ \sigma{mean}^t = \sigma{mean}^{t-1} * 0.9 + 0.1* \sigma σmeant=σmeant10.9+0.1σ
最后得到的 μ m e a n t \mu_{mean}^t μmeant σ m e a n t \sigma{mean}^{t} σmeant作为最终的值保存下来。供测试环节使用。

BN的好处以及原因

加速训练

输出分布向着激活函数的上下限偏移,带来的问题就是梯度的降低,(比如说激活函数是sigmoid),通过normalization,数据在一一个合适的分布空间,经过激活函数,仍然得到不错的梯度。梯度好了自然加速训练。

降低参数初始化敏感

以往模型需要设置一个不错的初始化才适合训练,加了BN就不用管这些了,现在初始化方法中随便选择一个用,训练得到的模型就能收敛。

PyTorch中BN源码解析

nn.BatchNorm2d继承_BatchNorm,BatchNorm2d仅仅负责查看tensor的尺寸是否符合要求。直接跳到_BatchNorm中。

        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)

上面是构造函数的一部分,其中running_mean,running_var就是用来记录均值和方差的滑动平均值的。都是用buffer来申请储存空间不是用parameter,是因为这不参与训练。weight和bias就是 r r r, σ \sigma σ,是训练参数。

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is self.momentum set to
        # (when it is available) only so that if gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

在前向过程中,可以看到self.training,如果是训练中使用BN,需要设置exponential_average_factor ,这个值就是我们上面讲解测试中使用bN用到的0.9。

你可能感兴趣的:(Pytorch,深度学习)