variables_to_restore是为了在保持模型的时候方便使用滑动平均的参数,如果不使用这个保存,那模型就会保存所以参数,除非你提前设定,就是在保存的时候指定保存变量也是可以的,比如saver = tf.train.Saver([v])这样就可以指定保存变量v,在模型导入的时候只有这个变量会被导入。
比如:
import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(tf.assign(v, 10.0))
sess.run(maintain_average_op)
saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
模型导入:
import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())
saver = tf.train.Saver()
with tf.Session() as sess:
# sess.run(tf.initialize_all_variables())
# sess.run(tf.assign(v, 10.0))
# sess.run(maintain_average_op)
# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
print sess.run(ema.average(v))
print sess.run(v)
输出:
0.0999999
10.0
这样不是很方便,因为我再次导入模型,变量v的值我不用,并且想要用计算后的值替代v,这样在模型被导入就方便就算
下面代码显示如何使用:
import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(tf.assign(v, 10.0))
sess.run(maintain_average_op)
saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
print sess.run(v)
print sess.run(ema.average(v))
# saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
# print sess.run(v)
输出:
10.0
0.0999999
导入模型的时候tf.train.Saver函数要变化一下,变为tf.train.Saver(ema.variables_to_restore()),代码如下:
import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
# sess.run(tf.initialize_all_variables())
# sess.run(tf.assign(v, 10.0))
# sess.run(maintain_average_op)
# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
# print sess.run(v)
# print sess.run(ema.average(v))
saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
print sess.run(v)
输出:
0.0999999
注意:如果不变的话,那么输出就会是10!