Pytorch中的BatchNorm2d的参数解释

参考链接:https://www.cnblogs.com/leebxo/p/10880399.html

BatchNorm2d中的track_running_stats参数

  • 如果BatchNorm2d的参数val,track_running_stats设置False,那么加载预训练后每次模型测试测试集的结果时都不一样;track_running_stats设置为True时,每次得到的结果都一样。

running_meanrunning_var参数

  • running_meanrunning_var参数是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。
torch.nn.BatchNorm1d(num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)

BatchNorm2d参数讲解

  • 一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
  • 同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainningaffinetrack_running_stats
  • 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=Falseγ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True
  • trainningtrack_running_statstrack_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
  • Pytorch中的BatchNorm2d的参数解释_第1张图片
  • Pytorch中的BatchNorm2d的参数解释_第2张图片

你可能感兴趣的:(Pytorch中的BatchNorm2d的参数解释)