tensorflow 中batch normalize 的使用

最近在学习slim,slim有个很好的地方就是:搭建网络方便,也有很多预训练模型下载。

但是最近在调slim中的resnet的时候,发现训练集有很高的accuracy(如90%),但是测试集的accuracy还是很低(如0%, 1%),这肯定不是由于欠拟合或者过拟合导致的。

后来发现是在做batch normalize的时候出了问题。
slim的使用batch normalize的时候很方便,不需要在每个卷积层后面显示地加一个batch normalize.只需要在slim里面的arg_scope中加入slim.batch_norm就可以。
如下操作就可以:

batch_norm_params = {
      'decay': batch_norm_decay,
      'epsilon': batch_norm_epsilon,
      'scale': batch_norm_scale,
      'updates_collections': tf.GraphKeys.UPDATE_OPS,
      'is_training': is_training
  }

  with slim.arg_scope(
      [slim.conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=slim.variance_scaling_initializer(),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    with slim.arg_scope([slim.batch_norm], **batch_norm_params):
      ...
      ...

言归正转,要注意的地方是,在做测试的时候,如果将is_training改为 False,就会出现测试accuracy很低的现象,需要将is_training改成True。虽然这样能得到高的accuracy,但是明显不合理!!
解决方法是:
因为batch_norm 在test的时候,用的是固定的mean和var, 而这个固定的mean和var是通过训练过程中对mean和var进行移动平均得到的。而直接使用train_op会使得模型没有计算mean和var,因此正确的方式是:
每次训练时应当更新一下moving_mean和moving_var

optimizer = tf.train.MomentumOptimizer(lr,momentum=FLAGS.momentum,
                                       name='MOMENTUM')
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies([tf.group(*update_ops)]):
    # train_op = slim.learning.create_train_op(total_loss, optimizer, global_step)
    train_op = optimizer.minimize(total_loss, global_step=global_step)

这样在测试的时候即使将is_training改成False也能得到正常的test accuracy了。

当然如果你还是没看懂,就戳这个链接:https://github.com/soloice/mnist-bn/blob/master/mnist_bn.py,里面有完整的代码。
或者这个:https://github.com/tensorflow/models/blob/master/slim/train_image_classifier.py

当然,其他的用法,例如tf.contrib.layers.batch_norm里面的batch normalize应该差不多,但是我没有用过,如果你用起来出了问题,可以戳下面两个链接看看能否找到答案。
1.http://ruishu.io/2016/12/27/batchnorm/
2.https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584

参考文献:
[1]https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584
[2].http://ruishu.io/2016/12/27/batchnorm/
[3]https://github.com/soloice/mnist-bn/blob/master/mnist_bn.py
[4]https://github.com/tensorflow/models/blob/master/slim/train_image_classifier.py

你可能感兴趣的:(机器学习,python)