版本: tensorflow-1.8.0
代码:Github
使用tf.train.Saver模块, 保存路径的URL名称一定要*.ckpt。
import tensorflow as tf
v1 = tf.get_variable("v1", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
v2 = tf.get_variable("v2", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
result = v1 + v2
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, "/home/jagger/workspace/tmp/model.ckpt")
模型会保存会在“/home/jagger/workspace/tmp/“目录下出现三个文件
其中.meta保存的是模型图结构,.index保存的是模型参数的索引,*.data保存的是模型参数具体数值。
如果要保存不同训练次数的模型,可以这样
saver.save(sess, "/home/jagger/workspace/tmp/model.ckpt", global_step=step)
这时候就会自动在.ckpt加上-次数,例如:
如果想只保存模型结构(Graph)
saver = tf.train.Saver()
saver.export_meta_graph("/home/jagger/workspace/tmp/model.ckpt.meta", as_text=True) # 可用编辑器打开,为json格式
从.ckpt文件中查看模型参数值
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# print all variables in checkpoint file
print("=========All Variables==========")
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt", tensor_name=None, all_tensors=True, all_tensor_names=True)
# print only tensor v1 in checkpoint file
print("=========V1==========")
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt", tensor_name='v1', all_tensors=False, all_tensor_names=False)
# print only tensor v2 in checkpoint file
print("=========V2==========")
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt", tensor_name='v2', all_tensors=False, all_tensor_names=False)
输出:
=========All Variables==========
tensor_name: v1
[-0.46169969]
tensor_name: v2
[ 0.40403476]
=========V1==========
tensor_name: v1
[-0.46169969]
=========V2==========
tensor_name: v2
[ 0.40403476]
如果模型参数文件是.ckpt-1000.index和.ckpt-1000.data形式,则输入的url也要加上-次数,例如
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt-1000",
tensor_name=None, all_tensors=True, all_tensor_names=True)
只载入模型参数,模型结构需要自己构建:
import tensorflow as tf
v1 = tf.get_variable("other-v1", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
v2 = tf.get_variable("other-v2", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
result = v1 + v2
saver = tf.train.Saver({"v1": v1, "v2": v2}) # 指定对应的参数,缺省则按参数名载入
with tf.Session() as sess:
saver.restore(sess, "/home/jagger/workspace/tmp/model.ckpt")
print(sess.run(result))
既载入模型结构(Graph),也载入参数值
import tensorflow as tf
# importing graph
saver = tf.train.import_meta_graph("/home/jagger/workspace/tmp/model.ckpt.meta")
with tf.Session() as sess:
# loading variable value to sess
saver.restore(sess, "/home/jagger/workspace/tmp/model.ckpt")
result = tf.get_default_graph().get_tensor_by_name("add:0")
print(sess.run(result))