使用tensorflow 的slim模块fine-tune resnet/densenet/inception网络,解决batchnorm问题

使用tf fine-tune resnet模型

前言


使用tensorflow踩了很多的坑,尤其是使用tf的slim模块的时候,其中batchnorm的问题困挠了我很久,问题表现如下:

  • 训练结果很好,测试的时候 istraining i s − t r a i n i n g 设置成false测试结果很差,设置成true测试结果恢复正常
  • 训练结果很好,但是测试的结果要差上不少

但是tensorflow官方提供的常见的网络代码以及与训练模型都是基于slim模块建立的,使用者可以直接fine-tune这些网络比如resnet, inception, densenet, 等等。但是经常有同学在使用过程中遇到结果不尽人意或者各种奇葩问题。

本文为上述提出的两个问题做一个总结,附上我的解决方案,有问题欢迎留言。

解决方案


tensorflow的slim地址,资源如下:
使用tensorflow 的slim模块fine-tune resnet/densenet/inception网络,解决batchnorm问题_第1张图片

每个网络都有对应的代码和预训练的模型,可以直接拿来fine-tune

坑1:

对于问题:训练结果很好,测试的时候 istraining i s t r a i n i n g 设置成false测试结果很差,设置成true测试结果恢复正常。
显然了是batchnorm的问题,假设要finetune-resnet-v1-101, 网络定义如下:

with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    net, end_points = resnet_v1_101.resnet_v1_101(imgs_processed,
                                                  num_classes=1000,
                                                  is_training=is_training,
                                                  global_pool=True,
                                                  output_stride=None,
                                                  spatial_squeeze=True,
                                                  store_non_strided_activations=False,
                                                  reuse=None,
                                                  scope='resnet_v1_101')

这个is_training 在测试的时候给成True,测试给为false,此参数控制网络batchnorm的使用,设置为true时,batchnorm中的beta和gama参与训练进行更新,设置成false的时候不更新,而是使用计算好的moving mean 和moving variance,关于batchnorm相关问题可以参考我的博文,因此,is_training 在测试的时候给成True,也就是在测试集上仍然更新batchnorm的参数,如果在训练集上训练的比较好了,在测试集上继续拟合,那结果肯定不会太差。

问题的原因是在测试的时候没有利用到moving mean 和moving variance,解决方案就是更新train op的时候同时更新batchnorm的op,即是在代码中做如下更改:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.group(*update_ops)
    self.cross_entropy = control_flow_ops.with_dependencies([updates], self.cross_entropy)

这样就可以将batchnorm的更新和train op的更新放在一起,也可以使用另一种方法:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = slim.learning.create_train_op(cross_entropy,
                                                          optimizer,
                                                          global_step=step,
                                                          variables_to_train=all_vars)
.
.
.
sess.run([train_op, extra_update_ops, cross_entropy])

作用都是一样的,但是值得注意的是,使用slim模块的时候建立train op时最好要使用slim自带的train op,具体代码如下:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)
train_op = slim.learning.create_train_op(cross_entropy,
                                                          optimizer,
                                                          global_step=step,
                                                          variables_to_train=all_vars)  # 选择性训练权重

而不是使用:

train_op = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(cross_entropy)

如果问题得到解决,那么恭喜,如果是在小数据集上fine-tune,可能还会遇到问题二,训练结果很好,但是测试的结果要差上不少。

坑二:


训练结果很好,但是测试的结果要差的问题出在batchnorm的decay参数上,先看一下slim中网络的arg scope定义,在resnet utiles.py的末尾可以找到如下代码:

def resnet_arg_scope(weight_decay=0.0001,
                     batch_norm_decay=0.99, #0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     activation_fn=tf.nn.relu,
                     use_batch_norm=True):
    batch_norm_params = {
          'decay': batch_norm_decay,
          'epsilon': batch_norm_epsilon,
          'scale': batch_norm_scale,
          'updates_collections': tf.GraphKeys.UPDATE_OPS,
          'fused': None,  # Use fused batch norm if possible.

      }

      with slim.arg_scope(
          [slim.conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay),
          weights_initializer=slim.variance_scaling_initializer(),
          activation_fn=activation_fn,
          normalizer_fn=tf.contrib.layers.batch_norm if use_batch_norm else None,
          normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm], **batch_norm_params):
          # The following implies padding='SAME' for pool1, which makes feature
          # alignment easier for dense prediction tasks. This is also used in
          # https://github.com/facebook/fb.resnet.torch. However the accompanying
          # code of 'Deep Residual Learning for Image Recognition' uses
          # padding='VALID' for pool1. You can switch to that choice by setting
          # slim.arg_scope([slim.max_pool2d], padding='VALID').
          with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
            return arg_sc

声明,在这里我没有使用slim.batchnorm,而是使用了tf.contrib.layers.batch_norm,二者差距不大,都是一样的,当然你也可以使用自己定义的batchnorm函数。

其中最重要的一个参数就是'decay': batch_norm_decay,原始的代码是在image net上训练的,decay设置的是0.999,这个数值越大,网络训练越平缓,相对需要更多的训练时间,但是在小数据集上训练的时候可以选用较小的数值,比如0.99或者0.95。

到这里坑就填完了,有问题可以在评论区提出。

代码在我的git上,根据我之前的多GPU并行代码改的,核心部分没有变,精度计算需要自己写:
https://github.com/LDOUBLEV/TF_resnet

你可能感兴趣的:(深度学习与计算机视觉,tensorflow,python,深度学习)