pytorch中batch normalisation的注意事项

torch中的各种批归一的注意事项,不间断更新20190122

含有batchnorm的网络其train和eval时效果差距大

亦可参考笔者的另一篇博文:Pytorch 深度学习 模型训练 断点继续训练时损失函数恶化或与断点差异较大

  1. 和是否zero_grad及其位置关系不大,因为这个错了,train是多半不收敛的。
  2. 主要是因为BN的输入随着训练的进行是时变的,非稳态的,除非训练完全收敛,且学习率很小,并进行了多个batch的训练,此时的running mean 和running var才会收敛到正确的值。
  3. 如果BN的动量为0.1, 那么需要多训练的batch数我认为至少是20,即0.9**20=0.1214,也就是说20个batch前的训练数据在running mean和var中所占比重约十分之一。
  4. 建议:当需要用eval运作网络时,最好先以train模式进行多个batch的前向传播,用于稳定running mean和var。

torch.nn.BatchNorm2d

  1. 输入4D的矩阵,NxCxHxW
  2. C维度取Ci时可计算得到MEANi和VERi,分别是改通道对应的均值和方差
  3. 可见该批归一化过程是通道间独立的。
  4. 所以,如果batch中N=1也是可以正常运作的,这点区别于最早的批归一文章。

你可能感兴趣的:(深度学习)