variables_to_restore函数的用法

variables_to_restore是为了在保持模型的时候方便使用滑动平均的参数,如果不使用这个保存,那模型就会保存所以参数,除非你提前设定,就是在保存的时候指定保存变量也是可以的,比如saver = tf.train.Saver([v])这样就可以指定保存变量v,在模型导入的时候只有这个变量会被导入。

比如:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver()
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())

	sess.run(tf.assign(v, 10.0))
	sess.run(maintain_average_op)
	saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
模型导入:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver()
with tf.Session() as sess:
	# sess.run(tf.initialize_all_variables())

	# sess.run(tf.assign(v, 10.0))
	# sess.run(maintain_average_op)
	# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')

	saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
	print sess.run(ema.average(v))
	print sess.run(v)
输出:

0.0999999
10.0

这样不是很方便,因为我再次导入模型,变量v的值我不用,并且想要用计算后的值替代v,这样在模型被导入就方便就算

下面代码显示如何使用:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver()
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())

	sess.run(tf.assign(v, 10.0))
	sess.run(maintain_average_op)
	saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
	print sess.run(v)
	print sess.run(ema.average(v))

	# saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
	# print sess.run(v)
输出:

10.0
0.0999999


导入模型的时候tf.train.Saver函数要变化一下,变为tf.train.Saver(ema.variables_to_restore()),代码如下:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
	# sess.run(tf.initialize_all_variables())

	# sess.run(tf.assign(v, 10.0))
	# sess.run(maintain_average_op)
	# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
	# print sess.run(v)
	# print sess.run(ema.average(v))

	saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
	print sess.run(v)
输出:

0.0999999


注意:如果不变的话,那么输出就会是10!




你可能感兴趣的:(tensorflow用法)