How to Use batch_normalization Correctly in TensorFlow

Principle

Batch normalization is typically applied to the inputs of a layer. It reshapes the distribution of each layer's inputs toward a normal distribution, which improves training stability and speeds up convergence.

The formula is:

$$y = \frac{\gamma(x-\mu)}{\sqrt{\sigma^2+\epsilon}} + \beta$$

where $\gamma$ and $\beta$ determine the final distribution, controlling the variance and the mean respectively, and $\epsilon$ prevents division by zero.
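To make the formula concrete, here is a minimal pure-Python sketch (not the TensorFlow implementation) that normalizes a small batch using its own batch statistics:

```python
import math

def batch_norm(xs, gamma=1.0, beta=0.0, eps=1e-3):
    """Normalize a 1-D batch: gamma * (x - mu) / sqrt(var + eps) + beta."""
    mu = sum(xs) / len(xs)                           # batch mean
    var = sum((x - mu) ** 2 for x in xs) / len(xs)   # batch variance
    return [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in xs]

out = batch_norm([1.0, 2.0, 3.0, 4.0], gamma=2.0, beta=0.5)
# The output has mean beta exactly, and variance close to gamma^2.
```

The output mean equals $\beta$ and the output variance is $\gamma^2\sigma^2/(\sigma^2+\epsilon)\approx\gamma^2$, which is exactly what "controlling the mean and variance" means above.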

TensorFlow

In practice, the mean $\mu$ and standard deviation $\sigma$ are determined jointly by accumulated historical statistics and the current batch:

$$\mu = momentum \cdot \mu + (1 - momentum) \cdot \mu_{batch}$$

$$\sigma = momentum \cdot \sigma + (1 - momentum) \cdot \sigma_{batch}$$

where $\mu_{batch}$ is the mean of the current batch, and $\sigma_{batch}$ its standard deviation.
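This update rule is a standard exponential moving average. A quick plain-Python sketch (the `ema_update` helper is hypothetical, for illustration only) shows how slowly the moving statistic tracks the batch statistic at the default momentum of 0.99:

```python
def ema_update(moving, batch, momentum=0.99):
    # moving statistic <- momentum * old value + (1 - momentum) * batch statistic
    return momentum * moving + (1 - momentum) * batch

moving_mean = 0.0  # moving_mean_initializer defaults to zeros
for _ in range(3):
    moving_mean = ema_update(moving_mean, 10.0)
# After 3 batches with mean 10.0: 10 * (1 - 0.99**3) ≈ 0.297
```

The moving mean is still far from the true batch mean after three steps, which is why the moving statistics need many training steps (and correctly wired update ops, see below) before they are reliable at inference time.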

API

In TensorFlow 1.x, the recommended API is:

tf.layers.batch_normalization(
    inputs, axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(), beta_regularizer=None,
    gamma_regularizer=None, beta_constraint=None, gamma_constraint=None,
    training=False, trainable=True, name=None, reuse=None, renorm=False,
    renorm_clipping=None, renorm_momentum=0.99, fused=None, virtual_batch_size=None,
    adjustment=None
)

The key parameters:

  1. momentum: per the formulas above, the weight given to the accumulated statistics versus the current batch;
  2. epsilon: $\epsilon$, which prevents division by zero;
  3. center: whether to include the offset $\beta$;
  4. scale: whether to include the scale $\gamma$;
  5. training: whether this is the training phase, which determines whether the moving mean and variance are updated or held fixed;
  6. trainable: whether $\gamma$ and $\beta$ are added to the trainable variables.

Correct Usage

$\gamma$ and $\beta$ are trainable variables, stored in tf.GraphKeys.TRAINABLE_VARIABLES.

The moving mean and variance, by contrast, are not trainable: they live only in tf.GraphKeys.GLOBAL_VARIABLES, and the ops that update them are stored in tf.GraphKeys.UPDATE_OPS.

So the most critical, and most error-prone, points are:

  1. During training, make sure the moving mean and variance are actually updated;
  2. During prediction, make sure all parameters match those from training; there are essentially four of them: $\gamma$, $\beta$, $\mu$, $\sigma$.

Training

During training, the moving-statistics update ops must be attached to train_op; calling optimizer.minimize alone does not run them:

x_norm = tf.layers.batch_normalization(x, training=True)

# ...

# Make the UPDATE_OPS run together with every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

Saving the Model

The moving mean and variance are GLOBAL_VARIABLES but not TRAINABLE_VARIABLES, so we should save all global variables explicitly. (Note that tf.train.Saver() with no var_list already defaults to all saveable variables; passing tf.global_variables() just makes the intent explicit and guards against accidentally saving only tf.trainable_variables().)

sess = tf.Session()

saver = tf.train.Saver(tf.global_variables())

saver.save(sess, "your_path")

Prediction

If the model saved the global variables correctly, then at prediction time the trained batch_normalization parameters can simply be restored.

Beyond that, training must be set to False, so that the moving mean and variance are held fixed.

x_norm = tf.layers.batch_normalization(x, training=False)

# ...

saver = tf.train.Saver(tf.global_variables())

saver.restore(sess, "your_path")
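What training=False buys you can be sketched in plain Python (a simplified sketch, not the actual layer): inference normalizes with the stored moving statistics, so a given input always maps to the same output, independent of whatever batch it arrives in:

```python
import math

def bn_inference(x, moving_mean, moving_var, gamma=1.0, beta=0.0, eps=1e-3):
    # With training=False the layer uses the stored moving statistics,
    # not the statistics of the current batch.
    return gamma * (x - moving_mean) / math.sqrt(moving_var + eps) + beta

# The same input always produces the same output at inference time.
y = bn_inference(2.5, moving_mean=2.5, moving_var=1.25)
```

If training were left at True during prediction, each output would instead depend on the other examples in the batch, and single-example prediction would break entirely (the batch variance of one example is zero).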

Estimator

If you train with the high-level Estimator API, things are trickier: the session is not exposed, so you cannot use the Saver directly and need a different approach:

  1. Fortunately, Estimator saves all variables (GLOBAL_VARIABLES) by default;
  2. The key is to make sure the eval and predict phases load the trained parameters. In your model_fn, add a model-loading step:
import collections
import re

import tensorflow as tf


def model_fn_build(init_checkpoint=None, lr=0.001, model_dir=None):

    def _model_fn(features, labels, mode, params):

        x = features['inputs']
        y = features['labels']

        ##################### define your own network here ###################
        x_norm = tf.layers.batch_normalization(x, training=mode == tf.estimator.ModeKeys.TRAIN)
        pred = tf.layers.dense(x_norm, 1)
        loss = tf.reduce_mean(tf.pow(pred - y, 2), name='loss')
        ##################### define your own network here ###################

        lr = params['lr']
        optimizer = tf.train.AdamOptimizer(lr)  # any optimizer works here

        ############ this loading step runs before eval and predict ############

        # Restore the saved model.
        # Use global_variables so the batch_normalization statistics are loaded too.
        tvars = tf.global_variables()
        initialized_variable_names = {}

        if params['init_checkpoint'] is not None or tf.train.latest_checkpoint(model_dir) is not None:
            checkpoint = params['init_checkpoint'] or tf.train.latest_checkpoint(model_dir)
            (assignment_map, initialized_variable_names
             ) = get_assignment_map_from_checkpoint(tvars, checkpoint)
            tf.train.init_from_checkpoint(checkpoint, assignment_map)

        # tf.logging.info("**** Trainable Variables ****")
        # for var in tvars:
        #     init_string = ""
        #     if var.name in initialized_variable_names:
        #         init_string = ", *INIT_FROM_CKPT*"
        #     tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
        #                     init_string)
        
        ############ this loading step runs before eval and predict ############

        if mode == tf.estimator.ModeKeys.TRAIN:
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
            return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

        if mode == tf.estimator.ModeKeys.EVAL:
            # mean squared error matches the regression loss above
            metrics = {"mse": tf.metrics.mean_squared_error(labels=y, predictions=pred)}
            return tf.estimator.EstimatorSpec(mode, eval_metric_ops=metrics, loss=loss)

        predictions = {'predictions': pred}
        predictions.update({k: v for k, v in features.items()})

        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    return tf.estimator.Estimator(_model_fn, model_dir=model_dir,
                                  params={"lr": lr, "init_checkpoint": init_checkpoint})


def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

    return (assignment_map, initialized_variable_names)
