Tensorflow.slim 库中 batch_normalization 的问题及其替代方法

Batch Normalization 的公式

y = γ x − μ σ + ϵ + β y = \gamma \frac{x-\mu}{\sigma + \epsilon}+\beta y=γσ+ϵxμ+β

  • 参数: x x x 是输入数据, y y y 是 batch normalization 的结果, γ \gamma γ β \beta β 是可学习的参数确保数据被过度正则化为均值 0 方差 1 的数据,而 μ \mu μ σ \sigma σ 是数据相关的确保当前训练数据正则化的效果。

Slim 的问题

Tensorflow.slim 库是集成度比较高的 API,其中实现 batch normalization 可以直接集成在 slim.conv2d函数中,如:

import tensorflow.contrib.slim as slim
slim.conv2d(input, num_ouput, kernel_size, stride, padding, rate, activation_fn,
			normalizer_fn=None, normalizer_params=None, ...)

即将 batch normalization 内容作为参数放在高度集成的卷积操作中,但是博主发现在我使用的 Tensorflow 1.10 版本中存在着 gamma 和 beta 参数更新错误的问题,即使采用了 https://blog.csdn.net/shwan_ma/article/details/83502333 等方法,也没有办法在 testing 的时候采用正确的参数。

解决方法

采用相比 slim 封装低一些的 layers 库,即 tf.layers.batch_normalization()

input = tf.placeholder(dtype=tf.float32, shape=[32,10,10,1], name='input') # 定义输入
output = tf.layers.batch_normalization(input, training=True, name='batch_normalization')

注意点

  • 参数 training,表明其中控制是否训练 γ \gamma γ β \beta β (一般在训练过程中) 。其中两个可训练的参数 γ \gamma γ β \beta β, 以及非训练的参数 滑动平均 μ \mu μ 和滑动方差 σ \sigma σ 可以由下面代码查看:
for n in tf.global_variables():
    print(n.name + ' with ', end=""), print(n.shape)
# Print: batch_normalization/gamma:0 with (32,)
#		 batch_normalization/beta:0 with (32,)
#        batch_normalization/moving_mean:0 with (32,)
#        batch_normalization/moving_variance:0 with (32,)

for n in tf.global_variables():
    print(n.name + ' with ', end=""), print(n.shape)
# Print: batch_normalization/gamma:0 with (32,)
#		 batch_normalization/beta:0 with (32,)
  • 训练时需要考虑更新滑动平均 μ \mu μ σ \sigma σ (滑动平均在测试时候用到),在调用优化器之前使用控制流确保滑动平均已经计算完成:
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops): # 控制流一定确保先执行
    train_op = optimizer.minimize(loss)

题外话:面试官经常问的 —— BN到底是对那个维度做的,得到的参数是多少维的?

  • BN 的参数的维度 = 特征的 channel 维度,可以通过上面的 print 结果得到。 因此是 batch normalization 是对每个 channel 维度做的。
  • 稍微吐槽一下 —— 个人认为这些问题毫无意义,只是面试官为了考察而考察。

你可能感兴趣的:(Tensorflow小Tips,机器学习)