终于知道为什么要freeze BN层,以及如何freeze(这个trick真的可以加快收敛)

一、什么是Batch Normalization(BN)层

BN层是数据归一化的方法,一般都是在深度神经网络中,激活函数之前,我们在训练神经网络之前,都会对数据进行预处理,即减去均值和方差的归一化操作。但是随着网络深度的加深,函数变的越来越复杂,每一层的输出的数据分布变化越来越大。BN的作用就是把数据强行拉回我们想要的比较好的正态分布下。这样可以在一定程度上避免梯度爆炸或者梯度消失的问题,加快收敛的速度。

二、BN是如何操作的

I n p u t : B = x 1... m ; γ , β ( 参 数 需 要 学 习 ) Input: B = {x_{1...m}}; \gamma, \beta(参数需要学习) Input:B=x1...m;γ,β()
O u t p u t : y i = B N γ β ( x i ) Output: {y_i = BN_{\gamma\beta}(x_i)} Output:yi=BNγβ(xi)
u B ← 1 m ∑ i = 1 m x i u_B \leftarrow \frac{1}{m}\sum_{i =1}^mx_i uBm1i=1mxi
σ B 2 ← 1 m ∑ i = 1 m ( x i − u B ) 2 \sigma_B^2 \leftarrow \frac{1}{m}\sum_{i =1}^m(x_i - u_B)^2 σB2m1i=1m(xiuB)2
x ~ ← x i − u B σ B 2 + ϵ \tilde{x} \leftarrow \frac{x_i - u_B}{\sqrt{\sigma_B^2+\epsilon}} x~σB2+ϵ xiuB
y i = γ x ~ i + β y_i = \gamma\tilde{x}_i+\beta yi=γx~i+β

BN工作流程:
1、计算当前batch_size数据的均值和方差;
2、将当前batch内的数据,normalize到均值为0,方差为1的分布上;
3、然后对normalized后的数据进行缩放和平移,缩放和平移的 γ 和 β \gamma和\beta γβ是可学习的。

BN层的状态包含4个参数:

  • weight,即缩放操作的\gamma
  • bias,缩放操作的\beta
  • running_mean,训练阶段在全训练数据上统计的均值,测试阶段会用到
  • running_var,训练阶段在全训练数据上统计的方差,测试阶段会用到

weight和bias这两个参数需要训练,而running_mean、running_val不需要训练,它们只是训练阶段的统计值。
训练时,均值、方差分别是该批次内数据相应维度的均值与方差;
推理时,均值、方差是基于所有批次的期望计算所得,

BN层的使用:
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
momentum:估计running_mean和 ruuning_var时使用
affine:如果为true,就学习参数 γ 和 β \gamma和\beta γβ,否则不学习。
track_running_stats:如果为true,持续跟踪running_mean,running_var

三、BN最大的作用

加快收敛。

四、为什么要freeze BN层

BN层在CNN网络中大量使用,可以看上面bn层的操作,第一步是计算当前batch的均值和方差,也就是bn依赖于均值和方差,如果batch_size太小,计算一个小batch_size的均值和方差,肯定没有计算大的batch_size的均值和方差稳定和有意义,这个时候,还不如不使用bn层,因此可以将bn层冻结。另外,我们使用的网络,几乎都是在imagenet上pre-trained,完全可以使用在imagenet上学习到的参数。

五、如何freeze BN层

有两种,一种是在训练阶段,将bn层变为eval(),即不更新统计running_mean和runn_val;另一种是需要将bn层的requires grad = False,BN层的参数weight和bias不优化,更新。
frozen: stop gradient update in norm layers
norm_eval: stop moving average statistics update in norm layers

def train(self, model=True):
  freeze_bn = False
  freeze_bn_affine = False
  supper(myNet, self).train(mode)
  if freeze_bn:
      print ("Freezing Mean/Var of BatchNorm2D.")
      for m in self.model.modules():
          if isinstance(m, nn.BatchNorm2d):
              m.eval()
      if freeze_bn_affine:
          print ("Freezeing Weight/Bias of BatchNorm2D.")
          if freeze_bn_affine:
              m.weight.requires_grad = False
              m.bias.requires_grad = False

两种freeze BN的方式,如何使用,我们来看一下《MMDetection: Open MMLab Detection Toolbox and Benchmark》里面的相关实验,在mmdetection中eval = True, requires grad = True是默认设置,不更新BN层的统计信息,也就是running_var和running_mean,但是优化更新其weight和bias的学习参数。
终于知道为什么要freeze BN层,以及如何freeze(这个trick真的可以加快收敛)_第1张图片
当GPU内存限制时,batch_size只能设置很小,例如1或者2,因此会对BN层进行freeze。上面的table6 时eval和requires_grad不同组合时的效果,该实验使用的网络是Mask R-CNN。Table 6显示,lr schedulex1时,更新统计信息,即eval = False,会损害网络性能,当eval = True,对权重weight 和 bias是否更新,即requires_grad = False or True,影响不大;但是lr_schedulex2中,eval=True, requires_grad = True 效果最好。

你可能感兴趣的:(实践,深度学习,batch,机器学习)