【深度学习笔记】Batch Normalization (BN)

Batch Normalization: Accelerating Deep Network Training b y Reducing Internal Covariate Shift这篇文章是谷歌2015年提出的一个深层网络训练技巧,Batch Normalization(简称BN)不仅可以加快了模型的收敛速度,而且更重要的是在一定程度缓解了深层网络中“梯度弥散”的问题(梯度弥散:0.9^{30}\approx 0.04,在BN中,通过将activation规范为均值和方差一致的手段使得原本会减小的activation的scale变大),从而使得训练深层网络模型更加容易和稳定。

【深度学习笔记】Batch Normalization (BN)_第1张图片

BN主要分为三步:

  1. 求每一个batch的数据均值和方差
  2. 使用求得的均值和方差对该批次的训练数据做归一化,获得0-1分布。其中\epsilon是为了避免分母为0。
  3. 尺度变换和偏移:将\hat{x_i}乘以\gamma调整数值大小,再加上\beta增加偏移后得到y_i,这里的\gamma控制缩放,\beta控制偏移。由于归一化后的\hat{x_i}基本会被限制在正态分布下,使得网络的表达能力下降,影响到network的capacity。为解决该问题,引入两个新的参数\gamma ,\beta,这两个参数是在训练时由网络学习得到的,如此一来,既可以改变同时也可以保持原输入,那么模型的容纳能力(capacity)就提升了。

【深度学习笔记】Batch Normalization (BN)_第2张图片

在训练时,会对同一批的数据的均值和方差进行求解,进而进行归一化操作。对于预测阶段时所使用的均值和方差,可以是来源于训练集,训练时每次计算每个batch的方差与均值,为了使得每个batch的方差与均值尽可能的接近整体分布方差与均值的估计值,这里通过滑动平均求整个训练样本的均值和方差期望值,作为我们进行预测时进行BN的的均值和方差。滑动系数为\lambda,当前batch计算的均值和方差为\mu,\sigma,那么

均值更新:\mu_{new} = \lambda\mu_{old}+\mu

方差更新,采用无偏估计:\sigma_{new} = \lambda\sigma_{old}+\sigma

在caffe的BN层中use_global_stats:如果为真,则使用保存的均值和方差,否则采用滑动平均计算新的均值和方差。该参数缺省的时候,如果是测试阶段则等价为真,如果是训练阶段则等价为假。

 

在tensorflow中,使用bn,注意以下几项:

1、训练时,模型输入参数training=True

    def forward(self, inputs, is_training=False, reuse=False):
        # set batch norm params
        batch_norm_params = {
            'decay': self.batch_norm_decay,
            'epsilon': 1e-05,
            'scale': True,
            'is_training': is_training,
            'fused': None,  # Use fused batch norm if possible.
        }

2、训练时,如果是使用var_list = tf.trainable_variables()是不包含通过滑动平均计算出的均值\mu和方差\sigma这两个参数,所以如下代码的方式,令var_list=update_vars

parser.add_argument("--update_part", nargs='*', type=str, default=['tiny_yolo/yolov3_head'],
                    help="Partially restore part of the model for finetuning. Set [None] to train the whole model.")

# define yolo-v3 model here
yolo_model = tiny_yolo(args.class_num, args.anchors)
with tf.variable_scope('tiny_yolo'):
    pred_feature_maps = yolo_model.forward(image, is_training=is_training)
loss = yolo_model.compute_loss(pred_feature_maps, y_true)
y_pred = yolo_model.predict(pred_feature_maps)


update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part)

# set dependencies for BN ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss[0], var_list=update_vars, global_step=global_step)

3、测试时,模型输入参数training=False即可

 

 

参考资料:

[1] https://arxiv.org/pdf/1502.03167.pdf

[2] https://www.cnblogs.com/skyfsm/p/8453498.html

[3]深度学习中 Batch Normalization为什么效果好?https://www.zhihu.com/question/38102762

[4]caffe层解读系列——BatchNorm https://blog.csdn.net/shuzfan/article/details/52729424

[5]https://www.cnblogs.com/hrlnw/p/7227447.html

你可能感兴趣的:(机器学习)