batch_normalization is generally applied to the data before it enters a layer of the network: it reshapes the distribution of each layer's inputs toward a standard normal distribution, which improves training stability and speeds up convergence.
The formula is:

$$\frac{\gamma(x-\mu)}{\sqrt{\sigma^2+\epsilon}}+\beta$$

Here $\gamma$ and $\beta$ determine the final distribution: they control its variance and mean respectively, and $\epsilon$ is a small constant that prevents division by zero.
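To make the formula concrete, here is a minimal NumPy sketch of the forward computation (the gamma, beta and eps values are arbitrary placeholders, not anything taken from the TF API below):

import numpy as np

def batch_norm_forward(x, gamma=1.0, beta=0.0, eps=1e-3):
    # normalize along the batch axis (axis 0), then scale and shift
    mu = x.mean(axis=0)                      # per-feature batch mean
    var = x.var(axis=0)                      # per-feature batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

x = np.random.randn(32, 4) * 5.0 + 3.0       # batch of 32 samples, 4 features
y = batch_norm_forward(x)
print(y.mean(axis=0), y.std(axis=0))         # roughly 0 and 1 per feature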
In real use, the mean $\mu$ and standard deviation $\sigma$ are determined jointly by the accumulated history and the current batch of samples:
$$\mu = \mathrm{momentum}\cdot\mu + (1-\mathrm{momentum})\cdot\mu_{batch}$$

$$\sigma = \mathrm{momentum}\cdot\sigma + (1-\mathrm{momentum})\cdot\sigma_{batch}$$

where $\mu_{batch}$ denotes the mean of the current batch (and $\sigma_{batch}$ its standard deviation).
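A short sketch of this exponential-moving-average update (0.99 is simply the default momentum of the TF API shown below):

def update_moving_stats(moving_mean, moving_std, batch_mean, batch_std, momentum=0.99):
    # the moving statistics are what batch normalization uses at inference time
    moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    moving_std = momentum * moving_std + (1.0 - momentum) * batch_std
    return moving_mean, moving_std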
In TensorFlow, the recommended API is:
tf.layers.batch_normalization(
inputs, axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
beta_initializer=tf.zeros_initializer(),
gamma_initializer=tf.ones_initializer(),
moving_mean_initializer=tf.zeros_initializer(),
moving_variance_initializer=tf.ones_initializer(), beta_regularizer=None,
gamma_regularizer=None, beta_constraint=None, gamma_constraint=None,
training=False, trainable=True, name=None, reuse=None, renorm=False,
renorm_clipping=None, renorm_momentum=0.99, fused=None, virtual_batch_size=None,
adjustment=None
)
Let's look at a few key parameters:
$\gamma$ and $\beta$ are trainable variables and live in tf.GraphKeys.TRAINABLE_VARIABLES,
while the moving mean and variance are not trainable: they only appear in tf.GraphKeys.GLOBAL_VARIABLES,
and the ops that update them are collected in tf.GraphKeys.UPDATE_OPS.
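You can check this yourself by inspecting the collections; the variable names in the comments assume the default layer name batch_normalization:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
x_norm = tf.layers.batch_normalization(x, training=True)

# gamma and beta, e.g. 'batch_normalization/gamma:0' and 'batch_normalization/beta:0'
print([v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
# moving_mean and moving_variance only show up among the global variables
print([v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)])
# the ops that refresh the moving statistics
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))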
This leads to the most critical point, and the one that is easiest to get wrong: during training, the update ops must be attached to train_op:
x_norm = tf.layers.batch_normalization(x, training=True)
# ... build the rest of the model, the loss and the optimizer ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = optimizer.minimize(loss)
train_op = tf.group([train_op, update_ops])  # run the moving-average updates together with each training step
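An equivalent pattern, and the one shown in the TensorFlow documentation, is to wrap minimize() in a control dependency so that the training step only runs together with the moving-statistics updates:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)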
Because the moving mean and variance live only in GLOBAL_VARIABLES and not in TRAINABLE_VARIABLES, make sure the Saver covers all global variables when saving the model; passing tf.global_variables() explicitly guarantees this:
sess = tf.Session()
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, "your_path")
If the model correctly saved the GLOBAL_VARIABLES, then at prediction time loading the checkpoint restores the trained batch_normalization parameters as well.
In addition, training must be set to False so that the moving mean and variance stay fixed instead of being recomputed from the batch:
x_norm = tf.layers.batch_normalization(x, training=False)
# ...
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, "your_path")
If you train with the high-level Estimator API, things are more awkward: the session is not exposed, so you cannot use the Saver directly and need a different approach:
import collections
import re

import tensorflow as tf

def model_fn_build(init_checkpoint=None, lr=0.001, model_dir=None, config=None):
    def _model_fn(features, labels, mode, params):
        x = features['inputs']
        y = features['labels']
        ##################### Define your own network here ###################
        # batch statistics are used only in TRAIN mode; eval/predict use the moving statistics
        x_norm = tf.layers.batch_normalization(
            x, training=(mode == tf.estimator.ModeKeys.TRAIN))
        pred = tf.layers.dense(x_norm, 1)
        loss = tf.reduce_mean(tf.pow(pred - y, 2), name='loss')
        ##################### Define your own network here ###################
        lr = params['lr']

        ########## Before eval and predict, this checkpoint-loading step is executed ##########
        # Load the saved model.
        # global_variables is required so the batch_normalization statistics are restored too.
        tvars = tf.global_variables()
        initialized_variable_names = {}
        if params['init_checkpoint'] is not None or tf.train.latest_checkpoint(model_dir) is not None:
            checkpoint = params['init_checkpoint'] or tf.train.latest_checkpoint(model_dir)
            (assignment_map, initialized_variable_names
             ) = get_assignment_map_from_checkpoint(tvars, checkpoint)
            tf.train.init_from_checkpoint(checkpoint, assignment_map)
            # tf.logging.info("**** Variables ****")
            # for var in tvars:
            #     init_string = ""
            #     if var.name in initialized_variable_names:
            #         init_string = ", *INIT_FROM_CKPT*"
            #     tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)
        ########## Before eval and predict, this checkpoint-loading step is executed ##########

        if mode == tf.estimator.ModeKeys.TRAIN:
            # the original snippet left the optimizer undefined; plain SGD is used here as an example
            optimizer = tf.train.GradientDescentOptimizer(lr)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
            train_op = tf.group([train_op, update_ops])
            return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

        if mode == tf.estimator.ModeKeys.EVAL:
            metrics = {"accuracy": tf.metrics.accuracy(features['labels'], pred)}
            return tf.estimator.EstimatorSpec(mode, eval_metric_ops=metrics, loss=loss)

        predictions = {'predictions': pred}
        predictions.update({k: v for k, v in features.items()})
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config,
                                  params={"lr": lr, "init_checkpoint": init_checkpoint})
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

    return (assignment_map, initialized_variable_names)
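Finally, a hypothetical usage sketch: train_input_fn and eval_input_fn are placeholders you would implement yourself, returning a features dict with 'inputs' and 'labels' keys as assumed by the _model_fn above:

estimator = model_fn_build(lr=0.001, model_dir="your_model_dir")

# Training: UPDATE_OPS are grouped into train_op inside _model_fn,
# so the moving statistics are refreshed at every step.
estimator.train(input_fn=train_input_fn, max_steps=10000)

# Eval / predict: mode != TRAIN, so training=False and the moving mean/variance
# restored from the checkpoint are used instead of batch statistics.
estimator.evaluate(input_fn=eval_input_fn)
for p in estimator.predict(input_fn=eval_input_fn):
    print(p['predictions'])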