TensorFlow's Saver

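Source code for the book "TensorFlow Machine Learning Cookbook" (《tensorflow机器学习实战指南》):
https://github.com/nfmcclure/tensorflow_cookbook
All of the files are .ipynb notebooks, so install Jupyter Notebook to open them: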

pip install --upgrade pip
pip install jupyter notebook
jupyter notebook



See the official documentation:
https://www.tensorflow.org/programmers_guide/saved_model?hl=zh-cn





save writes all of the variables to disk
savetest.py
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-2)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
	sess.run(init_op)
	# Do some work with the model.
	inc_v1.op.run()
	dec_v2.op.run()
	# Save the variables to disk.
	save_path = saver.save(sess, "./tmp/model.ckpt")
	print("Model saved in path: %s" % save_path)
	print("v1 : %s" % v1.eval())
	print("v2 : %s" % v2.eval())
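
The path passed to saver.save is a filename prefix, not a single file. The sketch below is a minimal illustration of what ends up on disk and how to find the newest checkpoint afterwards. It assumes TensorFlow 1.14 or later (or 2.x) so that the tensorflow.compat.v1 module is available, and it writes to a temporary directory instead of ./tmp; both choices are assumptions made for the example, not part of the original script.

```python
import os
import tempfile

# Assumption: tensorflow.compat.v1 is available (TF >= 1.14 or TF 2.x).
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
saver = tf.train.Saver()

# Hypothetical scratch directory used only for this example.
ckpt_dir = tempfile.mkdtemp()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    save_path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# "model.ckpt" is a prefix: list the files that were actually written.
written_files = sorted(os.listdir(ckpt_dir))
print(written_files)

# latest_checkpoint reads the "checkpoint" bookkeeping file in the
# directory and returns the prefix of the most recent save.
latest = tf.train.latest_checkpoint(ckpt_dir)
print(latest)
```

The value returned by latest_checkpoint can be passed straight to saver.restore.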




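restore loads the saved variables back
Even though the new program defines (and initializes) the same variables, restore overwrites their values with the ones stored in the checkpoint.

If the new program defines more variables than the program that saved the checkpoint, restore fails with an error saying the variable cannot be found.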
storetest.py
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.reset_default_graph()

# Create some variables.

v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
#v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
#v3 = tf.get_variable("v3", shape=[7], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
#dec_v2 = v2.assign(v2-3)
#dec_v3 = v3.assign(v3-7)

init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
	sess.run(init_op)
	# Do some work with the model.
	inc_v1.op.run()
	#dec_v2.op.run()
	#dec_v3.op.run()

	# Restore variables from disk.
	saver.restore(sess, "./tmp/model.ckpt")

	print("Model restored.")
	# Check the values of the variables
	print("v1 : %s" % v1.eval())
#	print("v2 : %s" % v2.eval())
#	print("v3 : %s" % v3.eval())
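
When the new program has a variable that the checkpoint lacks, one way around the error is to build the Saver with an explicit name-to-variable mapping so that only the listed variables are restored. The sketch below is a minimal illustration, not the original scripts: it assumes tensorflow.compat.v1 is available (TF 1.14 or later, or 2.x), uses a temporary directory as a hypothetical checkpoint location, and initializes v1 to ones so the restored value is easy to recognize.

```python
import os
import tempfile

# Assumption: tensorflow.compat.v1 is available (TF >= 1.14 or TF 2.x).
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical checkpoint location used only for this example.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Program A: save a checkpoint that contains only v1 and v2.
tf.reset_default_graph()
v1 = tf.get_variable("v1", shape=[3], initializer=tf.ones_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_prefix)

# Program B: defines v1 plus a new variable v3 that was never saved.
tf.reset_default_graph()
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[7], initializer=tf.zeros_initializer())

with tf.Session() as sess:
    # A default Saver asks the checkpoint for every variable, including v3.
    restore_failed = False
    try:
        tf.train.Saver().restore(sess, ckpt_prefix)
    except tf.errors.NotFoundError:
        restore_failed = True
    print("restoring every variable failed:", restore_failed)

    # Restore only v1 from the checkpoint; v3 has to be initialized by hand.
    tf.train.Saver({"v1": v1}).restore(sess, ckpt_prefix)
    sess.run(v3.initializer)
    v1_value = sess.run(v1)
    v3_value = sess.run(v3)
    print("v1:", v1_value)
    print("v3:", v3_value)
```

The dictionary keys are the names stored in the checkpoint, so the same mechanism can also load a saved value into a variable that has a different name in the new graph.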



To see every variable stored in the checkpoint, print them:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("./tmp/model.ckpt", tensor_name='', all_tensors=True)
print("------")
# tensor_name:  v1
# [ 1.  1.  1.]
# tensor_name:  v2
# [-2. -2. -2. -2. -2.]

# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./tmp/model.ckpt", tensor_name='v1', all_tensors=False)

print("------")

# tensor_name:  v1
# [ 1.  1.  1.]

# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./tmp/model.ckpt", tensor_name='v2', all_tensors=False)

# tensor_name:  v2
# [-2. -2. -2. -2. -2.]
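
print_tensors_in_checkpoint_file only prints. If the names, shapes, or values are needed inside a program, tf.train.list_variables and tf.train.load_variable return them as Python objects. The sketch below is a minimal illustration that first writes its own small checkpoint; the temporary directory, the initial values, and the use of tensorflow.compat.v1 (TF 1.14 or later, or 2.x) are assumptions made for the example.

```python
import os
import tempfile

# Assumption: tensorflow.compat.v1 is available (TF >= 1.14 or TF 2.x).
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Write a small checkpoint to a hypothetical scratch location.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.ones_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_prefix)

# list_variables returns (name, shape) pairs for everything in the checkpoint.
names_and_shapes = tf.train.list_variables(ckpt_prefix)
for name, shape in names_and_shapes:
    print(name, shape)

# load_variable reads one tensor by name and returns it as a NumPy array.
v1_saved = tf.train.load_variable(ckpt_prefix, "v1")
print(v1_saved)
```

Neither call needs a graph or a session, so they are convenient for checking a checkpoint before deciding which variables to restore.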
