TensorFlow中batch norm踩坑

最近在写多卡的TensorFlow版I3D的代码,其中遇到batch norm的坑,记录一波。

I3D使用的是snt.BatchNorm,当is_training = True时,意味着创建Update ops,利用当前batch的均值和方差去更新moving averages(即某层累计的平均均值和方差)。这里提供两种方式创建update_ops,

一是自己显式的创建update_ops,手动更新。update_ops默认放置在tf.GraphKeys.UPDATE_OPS中,因此这里在执行train_ops的同时更新均值方差即可,对于单卡来说很容易理解,对于多卡来说,相当于collection所有卡的batch的均值方差后统一更新,也可以只collection第一块卡的均值方差(理论上需要积累其他卡,但是由于这操作积累得很快,所以只取第一块卡也不影响性能,在TensorFlow高阶API的样例代码cifar10_main.py中如是说)。代码如下:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):

      train_op = optimizer.minimize(loss)

或者

update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))

      train_op = tf.group(train_op, update_ops)

二是自动的更新, 只需在初始化前 bn = BatchNorm(update_ops_collection=None)即可。不过这种方式下,会在完成更新前阻塞网络的forward,因此会带来时间上的成本。具体而言,这时bn的参数mean,var是立即更新的,也是计算完当前layer的mean,var就更新,然后进行下一个layer的操作。这在单卡下没有问题的, 但是多卡情况下就会写等读的冲突,因为可能存在GPU0更新(写)mean但此时GPU1还没有计算到该层,所以GPU0就要等GPU1读完mean才能写。

你可能感兴趣的:(TensorFlow中batch norm踩坑)