In short, BatchNorm computes the per-channel mean and variance of the input features (the 1st and 2nd formulas) and normalizes them (the 3rd formula: subtract the mean and divide by the standard deviation, giving zero mean and unit variance). However, this reduces the network's expressive power, so after normalization BN also applies a scale and shift, i.e. the learnable parameters $\gamma$ and $\beta$, which again are per-channel.
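As a sanity check of this per-channel computation, here is a minimal sketch that normalizes a 4D tensor by hand and compares it with nn.BatchNorm2d (the tensor shapes and tolerance are arbitrary choices for illustration; right after construction $\gamma=1$ and $\beta=0$, so the affine step is the identity):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 3, 16, 16)        # (N, C, H, W)
bn = nn.BatchNorm2d(3).train()       # gamma=1, beta=0 at initialization

mean = x.mean(dim=(0, 2, 3), keepdim=True)                # per-channel mean over N, H, W
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # per-channel (biased) variance
manual = (x - mean) / torch.sqrt(var + bn.eps)            # normalize: zero mean, unit variance

print(torch.allclose(bn(x), manual, atol=1e-6))  # True, since the affine transform is the identity here
```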
Why BatchNorm works is not fully understood: it may reduce Internal Covariate Shift, or it may make the optimization landscape smoother. In practice it allows a larger learning rate, relaxes the requirements on parameter initialization, and makes it possible to train deeper and wider networks. During training, BatchNorm uses only the current batch's mean and variance; at test/inference time, it uses the mean and variance computed with an EMA.
Take nn.BatchNorm2d as an example. Its inheritance chain is Module $\to$ _NormBase $\to$ _BatchNorm $\to$ BatchNorm2d. Module is the parent class of all PyTorch network modules.
_NormBase mainly registers and initializes the parameters:
```python
class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    def __init__(
        self,
        num_features: int,                 # number of feature channels
        eps: float = 1e-5,                 # added to the denominator to avoid division by zero
        momentum: float = 0.1,             # see the discussion of momentum below
        affine: bool = True,               # whether to apply the learnable scale/shift, i.e. \gamma and \beta
        track_running_stats: bool = True,  # whether to track running mean/variance (used for normalization in eval mode)
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))  # registers \gamma, initialized to 1 below
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))    # registers \beta, initialized to 0 below
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))  # registers the running mean, initialized to 0
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))    # registers the running variance, initialized to 1
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    # parameter initialization: \gamma is set to 1, \beta to 0
    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError
```
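The effect of this registration can be inspected on a freshly constructed layer (a quick illustrative check):

```python
import torch.nn as nn

bn = nn.BatchNorm2d(4)
print(dict(bn.named_parameters()).keys())  # dict_keys(['weight', 'bias'])  -> gamma, beta
print(dict(bn.named_buffers()).keys())     # dict_keys(['running_mean', 'running_var', 'num_batches_tracked'])
print(bn.weight.data, bn.bias.data)        # ones and zeros, as set by reset_parameters()

bn_plain = nn.BatchNorm2d(4, affine=False, track_running_stats=False)
print(list(bn_plain.parameters()), list(bn_plain.buffers()))  # both empty: None entries are skipped
```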
_BatchNorm then calls nn.functional.batch_norm in its forward, which performs the per-channel computation:
```python
class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,  # see the next section
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
```
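One consequence of the bn_training logic above: with track_running_stats=False the buffers are None, so batch statistics are used even in eval mode. A small sketch to illustrate (shapes are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 3, 16, 16) * 3 + 1

bn_no_stats = nn.BatchNorm2d(3, track_running_stats=False).eval()
bn_stats = nn.BatchNorm2d(3).eval()

y1 = bn_no_stats(x)  # buffers are None -> bn_training is True, batch stats are used
y2 = bn_stats(x)     # buffers exist (still 0/1)  -> bn_training is False, buffers are used

print(torch.allclose(y1, y2))  # False
print(y1.mean(dim=(0, 2, 3)))  # ~0 per channel: normalized with the batch statistics
```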
BatchNorm2d only specializes the input-dimension check:
```python
class BatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
```
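So BatchNorm2d simply insists on an (N, C, H, W) input, for example:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn(torch.randn(8, 3, 16, 16))   # OK: 4D (N, C, H, W)
try:
    bn(torch.randn(8, 3, 16))   # 3D input
except ValueError as e:
    print(e)                    # "expected 4D input (got 3D input)"
```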
According to the PyTorch documentation, momentum is used in the computation of running_mean and running_var. When it is set to None, a plain average (cumulative moving average) is used instead. The default value is 0.1.
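For example, with momentum=None the running mean after several batches is just the plain average of the per-batch means (an illustrative sketch; shapes are arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(1, momentum=None).train()  # cumulative moving average
batch_means = []
for _ in range(5):
    x = torch.randn(16, 1, 8, 8) + torch.rand(1)  # batches with varying means
    bn(x)
    batch_means.append(x.mean())

# running_mean equals the plain average of the batch means
print(torch.allclose(bn.running_mean, torch.stack(batch_means).mean(), atol=1e-5))  # True
```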
In _BatchNorm it is assigned to the averaging factor:
exponential_average_factor = self.momentum
When momentum is not None, this gives an exponential moving average (Exponential Moving Average, EMA), computed as:
$$\bar{x}_t = \beta \mu_t + (1-\beta)\bar{x}_{t-1}$$
where $\mu_t$ is the mean or variance of the current batch, and $\beta$ is exponential_average_factor. Expanding the recursion:
$$\begin{aligned} \bar{x}_t &= \beta \mu_t + (1-\beta)\left(\beta \mu_{t-1} + (1-\beta)(\beta \mu_{t-2} + (1-\beta)\bar{x}_{t-3})\right)\\ &= \beta \mu_t + (1-\beta)\beta \mu_{t-1} + (1-\beta)^2\beta \mu_{t-2} + \dots + (1-\beta)^t\beta \mu_0 \end{aligned}$$
As the formula shows, more recent batches receive larger weights, and the weights decay exponentially. The result is roughly an average over the most recent $\frac{1}{\beta}$ batches.
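This update can be checked against the module directly; with the default momentum the running mean follows $\bar{x}_t = 0.1\,\mu_t + 0.9\,\bar{x}_{t-1}$ (an illustrative sketch, shapes arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=0.1).train()
expected_mean = torch.zeros(3)                       # running_mean starts at 0

for _ in range(10):
    x = torch.randn(16, 3, 8, 8) + 2.0
    bn(x)
    mu = x.mean(dim=(0, 2, 3))                       # current batch mean per channel
    expected_mean = 0.1 * mu + 0.9 * expected_mean   # EMA update

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
```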