running_mean and running_var in PyTorch BatchNorm

1. Official source code

from typing import Optional

import torch
from torch import Tensor

def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor],
               training: bool, momentum: float, eps: float) -> Tensor:
    if training:
        # Training: compute the current batch's statistics and, as a side effect,
        # update running_mean / running_var in place (exponential moving average).
        norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
    else:
        # Evaluation: use the stored running statistics as-is.
        norm_mean = torch._unwrap_optional(running_mean)
        norm_var = torch._unwrap_optional(running_var)
    # Reshape the per-channel statistics so they broadcast over the input shape.
    norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
    norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
    norm_invstd = 1 / (torch.sqrt(norm_var + eps))
    return (input - norm_mean) * norm_invstd
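
As a quick sanity check of the eval-mode branch above, the following minimal sketch compares the manual computation against torch.nn.functional.batch_norm without affine weight and bias; the shapes and values here are made up purely for illustration.

import torch
import torch.nn.functional as F

x = torch.randn(4, 3)                # (N, C) input, arbitrary values
running_mean = torch.randn(3)
running_var = torch.rand(3) + 0.5    # keep the variance strictly positive
eps = 1e-5

# Eval-mode BN without affine parameters, mirroring the snippet above.
ref = F.batch_norm(x, running_mean, running_var, weight=None, bias=None,
                   training=False, momentum=0.1, eps=eps)
manual = (x - running_mean) / torch.sqrt(running_var + eps)

print(torch.allclose(ref, manual))   # expected: True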

2. Source code walkthrough

Denote the running statistics carried over from previous batches as mean_old, var_old, and the statistics of the current batch as mean_new, var_new.

During training, the running statistics are updated as an exponential moving average:
running_mean = (1 - momentum) * mean_old + momentum * mean_new
running_var = (1 - momentum) * var_old + momentum * var_new

During testing, the running statistics are left unchanged:
running_mean = mean_old
running_var = var_old

So the official code first updates running_mean and running_var, and then computes the BN output.
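
A minimal sketch, assuming a BatchNorm1d layer with momentum=0.1 and affine=False, that checks the update rule above; note that PyTorch folds the unbiased batch variance into running_var.

import torch

torch.manual_seed(0)
momentum = 0.1
bn = torch.nn.BatchNorm1d(3, momentum=momentum, affine=False)
x = torch.randn(8, 3)

mean_old = bn.running_mean.clone()       # previous running mean (zeros at init)
var_old = bn.running_var.clone()         # previous running var (ones at init)

bn.train()
bn(x)                                    # one training-mode forward updates the running stats

mean_new = x.mean(dim=0)                 # current batch mean
var_new = x.var(dim=0, unbiased=True)    # current batch variance (unbiased)

print(torch.allclose(bn.running_mean, (1 - momentum) * mean_old + momentum * mean_new))  # True
print(torch.allclose(bn.running_var, (1 - momentum) * var_old + momentum * var_new))     # True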

3. Discussion

The official implementation updates running_mean and running_var first, and then computes the BN output.
Would it be better to compute the BN output first and only update running_mean and running_var afterwards? The current batch would then not contribute to its own BN computation, which could remove the discrepancy between training-time and test-time BN. A rough sketch of this variant is given below.
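
A minimal sketch of that alternative, assuming a plain (N, C) input and no affine parameters; the function name and defaults here are hypothetical, not an existing PyTorch API.

import torch

def batch_norm_stats_after(x, running_mean, running_var, training, momentum=0.1, eps=1e-5):
    # Hypothetical variant: normalize with the *previous* running statistics,
    # so the current batch never contributes to its own normalization.
    out = (x - running_mean) / torch.sqrt(running_var + eps)
    if training:
        with torch.no_grad():
            # Only afterwards fold the current batch into the running statistics.
            running_mean.mul_(1 - momentum).add_(momentum * x.mean(dim=0))
            running_var.mul_(1 - momentum).add_(momentum * x.var(dim=0, unbiased=True))
    return out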
