tf.train.Saver函数的用法之保存全部变量和模型

 

用于保存模型,以后再用就可以直接导入模型进行计算,方便。

例如:

[python]  view plain  copy
  1. import tensorflow as tf;    
  2. import numpy as np;    
  3. import matplotlib.pyplot as plt;    
  4.   
  5. v1 = tf.Variable(tf.constant(1, shape=[1]), name='v1')  
  6. v2 = tf.Variable(tf.constant(2, shape=[1]), name='v2')  
  7.   
  8. result = v1 + v2  
  9.   
  10. init = tf.initialize_all_variables()  
  11.   
  12. saver = tf.train.Saver()  
  13.   
  14. with tf.Session() as sess:  
  15.     sess.run(init)  
  16.     saver.save(sess, "/home/penglu/Desktop/lp/model.ckpt")  
  17.     # saver.restore(sess, "/home/penglu/Desktop/lp/model.ckpt")  
  18.     # print sess.run(result)  
结果:


tf.train.Saver函数的用法之保存全部变量和模型_第1张图片



下次需要使用模型就可以用下面的代码:

[python]  view plain  copy
  1. import tensorflow as tf;    
  2. import numpy as np;    
  3. import matplotlib.pyplot as plt;    
  4.   
  5. v1 = tf.Variable(tf.constant(1, shape=[1]), name='v1')  
  6. v2 = tf.Variable(tf.constant(2, shape=[1]), name='v2')  
  7.   
  8. result = v1 + v2  
  9.   
  10. init = tf.initialize_all_variables()  
  11.   
  12. saver = tf.train.Saver()  
  13.   
  14. with tf.Session() as sess:  
  15.     saver.restore(sess, "/home/penglu/Desktop/lp/model.ckpt")  
  16.     print sess.run(result)  
[python]  view plain  copy
  1.   
[python]  view plain  copy
  1. 或者这个代码:  
[python]  view plain  copy
  1. import tensorflow as tf;    
  2. import numpy as np;    
  3. import matplotlib.pyplot as plt;  
[python]  view plain  copy
  1. saver = tf.train.import_meta_graph('/home/penglu/Desktop/lp/model.ckpt.meta')  
  2. with tf.Session() as sess:  
  3. "white-space:pre">  saver.restore(sess, "/home/penglu/Desktop/lp/model.ckpt")  
  4. "white-space:pre">  print sess.run(tf.get_default_graph().get_tensor_by_name('add:0'))  

你可能感兴趣的:(python,tensorflow)