批标准化 tf.keras.layers.BatchNormalization 中的trainable参数与training参数比较

巨坑提醒:tf.keras与tensorflow混用,trainable=False根本不起作用。正文不用看了。

摘要:

在tensorflow中,training参数和trainable参数是两个不同的参数,具有不同的约束功能。

  • training: True/False,告诉网络现在是在训练阶段还是在inference(测试) 阶段。
  • trainable:True/False, 设置这一层网络的参数变量是可在训练过程中被更新,是可训练还是不可训练。

代码1:tf.keras.layers.BatchNormalization的使用示例,及training参数和trainable的影响分析

import tensorflow.compat.v1 as tf

input = tf.ones([1, 2, 2, 3])
output = tf.keras.layers.BatchNormalization(trainable=True)(input,training=True)

# 手动添加滑动平均节点
ops = tf.get_default_graph().get_operations()
bn_update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type == "AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn_update_ops)


print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
print(tf.global_variables())
print(tf.trainable_variables())

"""
当 trainable=False,training=False 时,返回:
[[]]
[, , , ]
[, ]

当 trainable=False,training=True  时,返回:
[[, ]]
[, , , ]
[, ]

当 trainable=True,training=False  时,返回:
[[]]
[, , , ]
[, ]

当 trainable=True,training=True  时,返回:
[[, ]]
[, , , ]
[, ]


结论:可见影响滑动平均节点是否存在的参数主要是:training参数。当training=True时,tensorflow就会制造两个滑动平均节点出来,否则就不制造。
     而,不论滑动平均节点是否存在,tensorflow都会产生四个变量:gamma变量,beta变量和滑动平均变量,滑动方差变量。
     四个变量中,gamma变量,beta变量是可训练变量;滑动平均变量,滑动方差变量是不可训练变量--按照指定的滑动平均规则更新,与网络的损失函数没有任何关系。
     值得注意的是:(大坑提醒)无论trainable=True还是False,gamma变量,beta变量都被收集在可训练的变量集合里。这一点与tensorflow.keras的其他网络层的操作是一样的。大坑,大坑,大坑。。。
    
"""

代码2: tf.keras.layers.BatchNormalization手动添加滑动平均节点的重要性

与代码1相比,代码2的特点是,取消了手动添加滑动平均节点的环节。

import tensorflow.compat.v1 as tf

input = tf.ones([1, 2, 2, 3])
output = tf.keras.layers.BatchNormalization(trainable=True)(input,training=True)

## 手动添加滑动平均节点
# ops = tf.get_default_graph().get_operations()
# bn_update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type == "AssignSubVariableOp")]
# tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn_update_ops)


print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
print(tf.global_variables())
print(tf.trainable_variables())

"""
当 trainable=True,training=True  时,返回:
[]
[, , , ]
[, ]
"""

可见,即使设置了training=True,tensorflow制造了两个滑动平均节点,但是这两个节点并没有被加入到计算图上。因此,要想是批标准化实现滑动平均的计算,需要手动把这两个节点添加到计算图上。

建议:搭建完网络模型后,立马执行手动添加滑动平均节点,无论是training=True,还是training=False. 不用担心training=False时,网络会继续执行滑动平均操作,因为此时tensorflow压根不会创建这两个计算节点。

说明3: trainable在tensorflow2.0与1.*上的区别。以下说明复制自源码,我加了翻译和个人的理解,我的tensorflow版本是1.14.0.

先提炼一下这个声明想表达的意思:

  • 在tensorflow 2.0版本中,设置BatchNormalization层的trainable=False, 会导致这一批标准化层使用滑动均值和滑动方差来执行标准化。这一操作,主要是为了方便 fine-tune。具体情况如下:
    • trainable=True, training=True: 训练阶段,使用minibatch的均值方差执行标准化
    • trainable=False, training=True: 训练阶段,使用滑动均值滑动方差执行标准化。操作目的:fine-tune时,冻结已经训练好的批标准化层。
    • trainable=True, training=False: 测试阶段,使用训练期间的滑动均值滑动方差执行标准化
    • trainable=False, training=False: 测试阶段,使用训练期间的滑动均值滑动方差执行标准化
  • 在tensorflow 1.*版本中,设置BatchNormalization层的trainable=False,批标准化层将仍旧根据training参数来判断是使用minibatch的均值和方差还是使用滑动均值和方差来执行标准化。或者细化一点:
    • trainable=True, training=True: 训练阶段,使用minibatch的均值方差执行标准化
    • trainable=False, training=True: 训练阶段,使用minibatch的均值方差执行标准化
    • trainable=True, training=False: 测试阶段,使用训练期间的滑动均值滑动方差执行标准化
    • trainable=False, training=False: 测试阶段,使用训练期间的滑动均值滑动方差执行标准化

  **About setting `layer.trainable = False` on a `BatchNormalization layer:**
翻译:关于 BatchNormalization 层中 layer.trainable = False 的设置:
 
  The meaning of setting `layer.trainable = False` is to freeze the layer,  i.e. its internal state will not change during training: its trainable weights will not be updated  during `fit()` or `train_on_batch()`, and its state updates will not be run.
翻译:对于一个一般的层,设置layer.trainable = False表示冻结这一层的参数,使这一层的内部状态不随着训练过程改变,即这一层的可训练参数不被更新,也即,在`fit()` or `train_on_batch()`过程中,这一层的状态不会被更新。
 
  Usually, this does not necessarily mean that the layer is run in inference  mode (which is normally controlled by the `training` argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.
翻译:通常,设置layer.trainable = False并不一定意味着这一层处于inference状态(测试状态),(模型是否处于inference状态,通常调用该层的call函数时用一个叫training的参数控制。)所以,“冻结状态”和“推断模式”是两种不同的概念。
 
  However, in the case of the `BatchNormalization` layer, **setting  `trainable = False` on the layer means that the layer will be subsequently run in inference mode** (meaning that it will use the moving mean and the moving variance to normalize the current batch,  rather than using the mean and variance of the current batch).
翻译:但是,在BatchNormalization中,设置trainable = False 意味着这一层会以“推断模式”运行。这就意味着,如果在训练过程中设置批标准化层的trainable = False,就意味着批标准化过程中会使用滑动均值与滑动方差来执行当前批次数据的批标准化,而不是使用当前批次的均值与方差。
----》个人理解:对于批标准化,我们希望的是,在训练过程中使用每个minibatch自己的均值与方差执行标准化,同时保持一个滑动均值与滑动方差在测试过程中使用。如果在训练过程中,设置trainable = False的话,会导致,在训练过程中,批标准化层就会使用滑动均值与方差进行批标准化。
 
  This behavior has been introduced in TensorFlow 2.0, in order to enable `layer.trainable = False` to produce the most commonly expected behavior in the convnet fine-tuning use case.
翻译:这一操作已经被引入到TensorFlow 2.0中,目的是使`layer.trainable = False`产生最期待的行为:以便在网络fine-tune中使用。
---》个人理解:在网络fine-tune中,我们希望冻结一些层的参数,仅仅训练个别层的参数。对于批标准化层来说,我们希望这一层在训练过程中仍旧使用已经训练好的滑动均值和滑动方差,而不是当前批次的均值和方差。
 
  Note that:
    - This behavior only occurs as of TensorFlow 2.0. In 1.*,  setting `layer.trainable = False` would freeze the layer but would not switch it to inference mode.

翻译: 注意:这一行为仅仅发生在TensorFlow 2.0上。在1.*版本上,设置标准化层的`layer.trainable = False`,仍旧只会冻结标准化层的gamma和beta,仍旧使用当前批次的均值和方差标准化。
--》个人理解:在1.*版本上,设置标准化层的`layer.trainable = False`,得到的操作是:
    1)标准化层的gamma和beta不被训练
    2)执行标准化时,使用的是当前批次的均值和方差,而不是滑动均值和滑动方差。
    3)滑动均值和滑动方差仍旧会被计算吗?这有待确定。
    - Setting `trainable` on an model containing other layers will recursively set the `trainable` value of all inner layers.
翻译: 当给一整个model设置trainable参数时,相当于给其内部的每个层都设置了这一相同的参数。
    - If the value of the `trainable` attribute is changed after calling `compile()` on a model, the new value doesn't take effect for this model until `compile()` is called again.
翻译:如果,model在调用“compile()”时改变了trainable参数,新的trainable参数值并不影响这个model,直到再次调用“compile()”函数。
 

 个人的应用结论:如果在执行批标准化的时候不想使用gamma和beta变量进行平移和缩放,最好还是使用tensorflow 1.*版本,并设置trainable=False就可以了。

 

你可能感兴趣的:(批标准化 tf.keras.layers.BatchNormalization 中的trainable参数与training参数比较)