TensorFlow中关于saver读取中MovingAverage的一些注意事项

saver中保存滑动平均模型中,当我们直接定义一个滑动平均类的操作后,会自动生成变量列表中所对应的shaddow variables, 具体细节代码如下:

#part 1

v = tf.Variable(0, dtype=tf.float32, name="v")
#创建滑动平均的类,给定初始衰减率0.5
ema = tf.train.ExponentialMovingAverage(0.5) 
#定义一个更新变量滑动平均操作。注意这里需要给定一个列表,每次执行这个操作的时候列表中的变量都会被更新 
ema_op = ema.apply(tf.global_variables())
#请注意,当我们定义好了变量的滑动平均操作之后,会自动生成一个"v/ExponentialMovingAverage"的影子变量
for variable in tf.global_variables():
    print(variable.name)
#输出:
#v:0
#v/ExponentialMovingAverage:0

#给变量赋值看v和v/ExponentialMovingAverage的取值并保存模型:
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v, 10))
    sess.run(ema_op)
    saver.save(sess, "./save_model/model.ckpt")
    print(sess.run([v, ema.avergae(v)]))       #其中ema.average(v)就代表了v的shaddow variable
#输出:
#[10.0, 5.0]

#part2 从saver中读取数据

v = tf.Variable(3.0, dtype=tf.float32, name="v")
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})      #把上面所保存的滑动平均值加载给v值  
with tf.Session() as sess:
    saver.restore(sess, "./save_model/model.ckpt") 
    print(sess.run(v))
#输出:
#5.0             注意,这个值是从文件中读取的v的滑动平均的值,如果想加载v的原值,直接tf.train.Saver({"v": v})或者tf.train.Saver()就行

#part3  当变量变得很多的时候,通过字典的方式来加载滑动平均值就显得不可能了,所以在TensorFlow中提供了variables_to_restore函数
#来生成tf.train.Saver()类所需的变量重命名字典, 样例如下:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.5)
#通过使用variables_to_restore函数来生成上面tf.train.Saver()所需的字典
print(ema.variables_to_restore())
#输出:
#{"v/ExponentialMovingAverage": }

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "./save_model/model.ckpt")
    print(sess.run(v)          #输出 5.0,  也就是原来模型中变量v的滑动平均值
总结:当我们想运用以及训练好的权重的滑动平均值来预测数据的时候(滑动平均值可以让神经网络模型变得更加健壮)直接通过用variables_to_restore来生成我们所需的变量重命名字典,这样效率跟自己手动定义字典相比大幅提升,而且代码简洁。







你可能感兴趣的:(TensorFlow学习心得)