tensorflow 中batch normalize 的使用

最近在学习slim,slim有个很好的地方就是:搭建网络方便,也有很多预训练模型下载。

但是最近在调slim中的resnet的时候,发现训练集有很高的accuracy(如90%),但是测试集的accuracy还是很低(如0%, 1%),这肯定不是由于欠拟合或者过拟合导致的。


因为batch_norm 在test的时候,用的是固定的mean和var, 而这个固定的mean和var是通过训练过程中对mean和var进行移动平均得到的。而直接使用train_op会使得模型没有计算mean和var,因此正确的方式是:

每次训练时应当更新一下moving_mean和moving_var

optimizer = tf.train.MomentumOptimizer(lr,momentum=FLAGS.momentum,

                                      name='MOMENTUM')

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.control_dependencies([tf.group(*update_ops)]):

    # train_op = slim.learning.create_train_op(total_loss, optimizer, global_step)

    train_op = optimizer.minimize(total_loss, global_step=global_step)


这样在测试的时候即使将is_training改成False也能得到正常的test accuracy了。

你可能感兴趣的:(tensorflow 中batch normalize 的使用)