Pytorch中批规范化(nn.BatchNorm2d())

有时模型训练好了,将训练完成后的参数读入网络做测试的时候发现效果变差。这极有可能就是BatchNorm出现问题。下面就对pytorch中的nn.BatchNorm2d()做一个详解。
这里先放上原文链接(大部分参考这篇文章)

torch.nn.BatchNorm2d(num_features, eps=1e-05,momentum=0.1,affine=True, track_running_stats=True)

批标准化的过程如下:
Pytorch中批规范化(nn.BatchNorm2d())_第1张图片

  • num_features为输入四维tensor的通道数
  • affine指定是否需要仿射。如果affine=True,那么γ和β将作为可以被训练的参数参与学习训练;如果affine=False,则γ=1 β=0,并且不能学习被更新。
    在pytorch中,分别用weight和bias来表示γ和β。其数据维数由num_features决定:
bn=nn.BatchNorm2d(10)
print(bn.weight.size())#torch.Size([10])
print(bn.bias.size())#torch.Size([10])

由此可以看到,BN层中的仿射变换是逐通道进行的。

  • training和track_running_stats
    一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性training指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
    track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
    pytorch用running_mean表示均值,running_var表示方差。
  1. training=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
  2. training=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
  3. training=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_mean和running_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。
  4. training=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。

注意,BN层中的running_mean和running_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。

bn=nn.BatchNorm2d(10)
print(bn.running_mean.size())#torch.Size([10])
print(bn.running_var.size())#torch.Size([10])

由此可以看到,BN层中的running_mean和running_var是按通道记录的。

如果加载训练好的参数,需要用model.eval()将模型转到测试阶段,才能固定住running_mean和running_var。

假设一个场景,如下图所示:
Pytorch中批规范化(nn.BatchNorm2d())_第2张图片
此时为了收敛容易控制,先预训练好模型model_A,并且model_A内含有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时就希望model_A中的BN的统计特性值running_mean和running_var不会乱变化,因此就必须用model_A.eval()设置到测试模式,否则在training模式下,就算是不去更新该模型的参数,其BN都会改变的,这个将会导致和预期不同的结果。

转自:https://www.cnblogs.com/leebxo/p/10880399.html

你可能感兴趣的:(Pytorch中批规范化(nn.BatchNorm2d()))