pytorch中BatchNorm2d的用法

CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

基本原理为:
pytorch中BatchNorm2d的用法_第1张图片
如图所示,块标准化的目的就是让传输的数据合理的分布,加速训练的过程。
输入为一个四维数据(N,C,H,W),N-输入的batch size,C是输入的图像的通道数,(H,W)为输入的图像的尺寸。
对于每一个输入特征通道,所有样本的特征图做归一化处理。
这里有个博客解释的很好
https://www.jianshu.com/p/fcc056c1c200

  • num_features – C

  • eps 默认1e-5,加在分母上保持数据稳定(不会出现分母为0 的错误)

  • momentum 默认0.1,在训练中用于对均值和方差的估计
    pytorch中BatchNorm2d的用法_第2张图片

  • affine 默认为True,表示参数可学习(即γ和β)

  • track_running_stats 默认True ,表示训练时要对均值和方差进行估计(这里我的理解是有点类似于指数加权平均的意思)
    track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。

你可能感兴趣的:(pytorch学习笔记)