https://github.com/nfmcclure/tensorflow_cookbook
全都是.ipynb 文件
pip install --upgrade pip
pip install jupyter notebook
jupyter notebook
参考官方文档
https://www.tensorflow.org/programmers_guide/saved_model?hl=zh-cn
save把变量都存下来
savetest.py
# savetest.py
# Build two zero-initialised variables, modify each once, then write every
# variable in the graph to a checkpoint (TensorFlow 1.x graph/session API).
import os

import tensorflow as tf

# Raise the threshold of TensorFlow's native (C++) logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

# Assignment ops: v1 <- v1 + 1 and v2 <- v2 - 2.
inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 2)

# Op that initialises every global variable.
init_all = tf.global_variables_initializer()

# Saver created without arguments, after both variables exist.
checkpointer = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_all)

    # Apply each assignment exactly once.
    sess.run(inc_v1.op)
    sess.run(dec_v2.op)

    # Write the current variable values to disk.
    written_to = checkpointer.save(sess, "./tmp/model.ckpt")
    print("Model saved in path: %s" % written_to)

    v1_value, v2_value = sess.run([v1, v2])
    print("v1 : %s" % v1_value)
    print("v2 : %s" % v2_value)
restore把存的变量都取出来
虽然定义了相同的变量,restore的时候会覆盖掉
如果新程序的变量多于save的程序,会报找不到变量的错误
storetest.py
# storetest.py
# Rebuild the graph, deliberately change v1, then restore from the checkpoint
# to show that the restored value replaces whatever the session held before.
import os

import tensorflow as tf

# Raise the threshold of TensorFlow's native (C++) logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Start from an empty default graph.
tf.reset_default_graph()

# Variables to restore. v2 and v3 are optional toggles, left disabled here.
# NOTE(review): per the notes accompanying this script, enabling a variable
# that was never saved (v3) is expected to fail on restore with a
# "variable not found" error -- confirm by uncommenting.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
# v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
# v3 = tf.get_variable("v3", shape=[7], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
# dec_v2 = v2.assign(v2 - 3)
# dec_v3 = v3.assign(v3 - 7)

init_all = tf.global_variables_initializer()

# Saver created without arguments, after the variables above exist.
restorer = tf.train.Saver()

with tf.Session() as sess:
    # Initialise and modify v1 first, so the effect of restore is visible.
    sess.run(init_all)
    sess.run(inc_v1.op)
    # sess.run(dec_v2.op)
    # sess.run(dec_v3.op)

    # Load the saved values from disk.
    restorer.restore(sess, "./tmp/model.ckpt")
    print("Model restored.")

    # Show the variable values after the restore.
    print("v1 : %s" % sess.run(v1))
    # print("v2 : %s" % sess.run(v2))
    # print("v3 : %s" % sess.run(v3))
如果想看所有变量,可以打印
# Print the tensors stored in a checkpoint: first all of them, then v1 only,
# then v2 only.
from tensorflow.python.tools import inspect_checkpoint as chkp

CKPT_PATH = "./tmp/model.ckpt"

# Assuming the checkpoint was written by savetest.py (v1 = 0 + 1, v2 = 0 - 2),
# the expected values are:
#   tensor_name: v1  ->  [ 1.  1.  1.]
#   tensor_name: v2  ->  [-2. -2. -2. -2. -2.]

# Every tensor in the checkpoint file.
chkp.print_tensors_in_checkpoint_file(
    CKPT_PATH, tensor_name='', all_tensors=True)
print("------")

# Only tensor v1.
chkp.print_tensors_in_checkpoint_file(
    CKPT_PATH, tensor_name='v1', all_tensors=False)
print("------")

# Only tensor v2.
chkp.print_tensors_in_checkpoint_file(
    CKPT_PATH, tensor_name='v2', all_tensors=False)