Pytorch FrozenBatchNorm (BN)

FrozenBatchNorm就是"weight" and "bias", "running_mean", "running_var”四个值固定住的BN

经典框架中一直使用的是FrozenBatchNorm2d。如Detectron,DETR,  mmdetection?见

检测框架中BN - 知乎 (zhihu.com)

"weight" and "bias", "running_mean", "running_var”四个值是buf,通过register_buffer设置不更新。

为什么要使用FrozenBatchNorm

BN层在CNN网络中大量使用,但是BN依赖于均值和方差,如果batch_size太小,计算一个小batch_size的均值和方差,肯定没有计算大的batch_size的均值和方差稳定和有意义,这个时候,还不如不使用bn层,因此可以将bn层冻结。另外,我们使用的网络,几乎都是在imagenet上pre-trained,完全可以使用在imagenet上学习到的参数。


而且,如果使用的是FrozenBatchNorm,多卡训练就不会有BN同步的问题了,那么多卡训练的性能理论上应该和单卡一样好了,注意这点

torchvision.ops.FrozenBatchNorm2d(num_features: int, eps: float = 1e-05)

你可能感兴趣的:(pytorch,深度学习,人工智能)