In TensorFlow, saving and loading a model is done through the tf.train.Saver class: what to save, how often to save, and how many checkpoints to keep are all controlled by the arguments you pass in. The actual work is done by two methods of this class, save and restore.
Official guide: https://www.tensorflow.org/guide/saved_model
The Saver constructor:
__init__(
    var_list=None,            # variables to save/restore, as a list or a dict; defaults to all saveable objects
    reshape=False,
    sharded=False,
    max_to_keep=5,            # keep at most this many recent checkpoint files
    keep_checkpoint_every_n_hours=10000.0,  # additionally keep one checkpoint for every N hours of training
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,  # write relative paths into the checkpoint state file
    filename=None               # filename to use for the checkpoint files
)
For the full parameter reference, see the official docs: https://www.tensorflow.org/api_docs/python/tf/train/Saver
Create an instance before the session starts:
# With no arguments, all variables are saved by default
saver = tf.train.Saver()
# Pass the variables to save as a list or a dict
# saver = tf.train.Saver([v1])
# Limit how many recent checkpoints are kept
# saver = tf.train.Saver(max_to_keep=4)
# Additionally keep one checkpoint for every N hours of training, e.g. every 2 hours
# saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
The save method then writes the model to a given directory; see the example below.
The save method's parameters:
save(
    sess,                      # the session whose variables to save
    save_path,                 # path prefix, including the model name, for the checkpoint files
    global_step=None,          # if given, appended to save_path to number the checkpoint
    latest_filename=None,      # name of the checkpoint state file, defaults to 'checkpoint'
    meta_graph_suffix='meta',  # suffix of the MetaGraphDef file, defaults to 'meta'
    write_meta_graph=True,     # whether to write the meta graph
    write_state=True           # whether to update the checkpoint state file
)
The method returns the path prefix at which the checkpoint was saved.
Official example:
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)
inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# With no arguments, all variables are saved by default
saver = tf.train.Saver()
# Pass the variables to save as a list or a dict
# saver = tf.train.Saver([v1])
# Limit how many recent checkpoints are kept
# saver = tf.train.Saver(max_to_keep=4)
# Additionally keep one checkpoint for every N hours of training, e.g. every 2 hours
# saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)

# Launch the model, initialize the variables, do some work, and save the variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    print(v1.eval())
    print(v2.eval())
    # Save the variables to disk: files named 'my_test.*' under ./save_restore_model
    saver.save(sess, "./save_restore_model/my_test")
    # Save every 1000 iterations, numbering the checkpoint with the step
    # saver.save(sess, "test_save_restore_model", global_step=1000, write_meta_graph=False)
# Output
[1. 1. 1.]
[-1. -1. -1. -1. -1.]
The result of saving: four files are generated: checkpoint, my_test.data-00000-of-00001, my_test.index, and my_test.meta.
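The checkpoint file itself is a small text file recording which checkpoints exist in the directory. A minimal sketch (assuming the save above succeeded) of reading it with tf.train.get_checkpoint_state:
import tensorflow as tf
# Read the 'checkpoint' state file to see which checkpoints the directory holds
ckpt_state = tf.train.get_checkpoint_state('./save_restore_model')
print(ckpt_state.model_checkpoint_path)       # ./save_restore_model/my_test
print(ckpt_state.all_model_checkpoint_paths)  # all checkpoints tracked by the Saver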
Use the Saver's restore() method to restore variable values directly from the saved files.
restore takes only two arguments: the session to restore into and the path the model was saved under.
restore(
    sess,
    save_path
)
Before restoring, we must recreate variables identical to the ones that were saved. The example above has two variables named v1 and v2:
v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)
so when restoring, we create the same two variables:
v1 = tf.get_variable("v1", [3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer=tf.zeros_initializer)
Official example:
import tensorflow as tf

tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer=tf.zeros_initializer)

# Restore all variables
saver = tf.train.Saver()
# Or restore only v2
# saver = tf.train.Saver({"v2": v2})

# Use the saver object normally after that.
with tf.Session() as sess:
    # Initialize v1 by hand; with the {"v2": v2} saver above it would not be restored.
    v1.initializer.run()
    saver.restore(sess, "./save_restore_model/my_test")
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
# Output
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
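For completeness, a minimal sketch of the partial-restore variant hinted at in the comments above: only v2 is restored from the checkpoint, so v1 must be initialized by hand.
import tensorflow as tf

tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer=tf.zeros_initializer)
# Map the checkpoint name "v2" to the variable v2; nothing else is restored
saver = tf.train.Saver({"v2": v2})

with tf.Session() as sess:
    v1.initializer.run()  # v1 is NOT restored, so initialize it manually
    saver.restore(sess, "./save_restore_model/my_test")
    print("v1 : %s" % v1.eval())  # [0. 0. 0.]  (freshly initialized)
    print("v2 : %s" % v2.eval())  # [-1. -1. -1. -1. -1.]  (restored)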
In practice, when we use someone else's model and the model is large, recreating every variable by hand becomes impractical. Fortunately, the graph was already written to the .meta file when the model was saved, so we can use tf.train.import_meta_graph to restore the graph directly from that file.
tf.train.import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)
Example:
import tensorflow as tf
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./save_restore_model/my_test.meta')
    new_saver.restore(sess, "./save_restore_model/my_test")
    print(sess.run('v1:0'))
# Output
[1. 1. 1.]
Use the tf.train.latest_checkpoint() function to find the most recently saved checkpoint file.
tf.train.latest_checkpoint(
    checkpoint_dir,       # directory where the variables were saved
    latest_filename=None  # optional name of the checkpoint state file, defaults to 'checkpoint'
)
Example:
import tensorflow as tf
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./save_restore_model/my_test.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./save_restore_model'))
    print(sess.run('v1:0'))
# Output
[1. 1. 1.]
Alternatively, use graph.get_tensor_by_name() to fetch a saved tensor by name; see the get_tensor_by_name() examples near the end of this article.
Above we could simply print v1 because we knew a variable with that name had been saved. When working with someone else's model, how do we find out which variables it contains?
The inspect_checkpoint library lets us quickly inspect the variables in a checkpoint.
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("./save_restore_model/my_test", tensor_name='', all_tensors=True)
# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./save_restore_model/my_test", tensor_name='v1', all_tensors=False)
# tensor_name: v1
# [ 1. 1. 1.]
# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./save_restore_model/my_test", tensor_name='v2', all_tensors=False)
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
# Output
tensor_name: v1
[1. 1. 1.]
tensor_name: v2
[-1. -1. -1. -1. -1.]
tensor_name: v1
[1. 1. 1.]
tensor_name: v2
[-1. -1. -1. -1. -1.]
The following code is from muqiusangyang: https://my.oschina.net/u/3800567/blog/1637800
import tensorflow as tf

# Get the names of the trainable variables stored in a checkpoint
def get_trainable_variables_name_from_ckpt(meta_graph_path, ckpt_path):
    # Define a new graph
    graph = tf.Graph()
    # Make it the default graph
    with graph.as_default():
        with tf.Session() as session:
            # Load the computation graph from the .meta file into this graph
            saver = tf.train.import_meta_graph(meta_graph_path)
            # Restore the model into the graph associated with this session
            saver.restore(session, ckpt_path)
            v_names = []
            # Collect the trainable variables of the session's graph.
            # Note: tf.trainable_variables() only returns variables defined
            # before it is called, not ones defined afterwards.
            for v in tf.trainable_variables():
                v_names.append(v.name)  # append the name, not the Variable object
            return v_names
if __name__ == '__main__':
    meta_graph_path = "./save_restore_model/my_test.meta"
    ckpt_path = tf.train.latest_checkpoint('./save_restore_model')
    v_names = get_trainable_variables_name_from_ckpt(meta_graph_path, ckpt_path)
    print(v_names)
# Output
['v1:0', 'v2:0']
Alternatively, pywrap_tensorflow can read all variables in a ckpt file, returning a map from variable name to shape:
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

def get_all_variables_name_from_ckpt(ckpt_path):
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    all_var = reader.get_variable_to_shape_map()
    # reader.get_variable_to_dtype_map() would return a name -> dtype map instead
    return all_var

if __name__ == '__main__':
    ckpt_path = tf.train.latest_checkpoint('./save_restore_model')
    all_var = get_all_variables_name_from_ckpt(ckpt_path)
    print(all_var)
# Output
{'v1': [3], 'v2': [5]}
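An alternative sketch: tf.train.list_variables yields the same name/shape pairs without touching pywrap_tensorflow directly (available in recent TF 1.x releases).
import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint('./save_restore_model')
# Each entry is a (name, shape) pair read from the checkpoint
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)
# v1 [3]
# v2 [5]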
def copy_var_from_ckpt(session, dst_var_name, dst_var, ckpt_path, meta_graph_path):
    # Define a new graph
    graph = tf.Graph()
    # Make it the default graph
    with graph.as_default():
        with tf.Session() as sess:
            # Load the computation graph from the .meta file into this graph
            saver = tf.train.import_meta_graph(meta_graph_path)
            # Restore the model into this graph through the temporary session
            saver.restore(sess, ckpt_path)
            v_names = []
            # Collect the names of the trainable variables in this graph.
            # Note: tf.trainable_variables() only returns variables defined
            # before it is called.
            for v in tf.trainable_variables():
                v_names.append(v.name)  # compare names, not Variable objects
            if dst_var_name in v_names:
                # Look up the tensor by name
                tensor = graph.get_tensor_by_name(dst_var_name)
                # Fetch its value, i.e. the weights stored in the checkpoint
                weight = sess.run(tensor)
                # Copy the weights into dst_var; this must run in the session
                # that owns dst_var. dst_var is a Variable, weight is a numpy
                # array, so an assign op does the copy.
                session.run(dst_var.assign(weight))
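A hypothetical usage sketch for copy_var_from_ckpt (the name dst_v1 below is illustrative, not from the original): copy the checkpointed value of 'v1:0' into a fresh variable in the current graph.
import tensorflow as tf

dst_v1 = tf.get_variable("dst_v1", [3])  # shape matches v1 in the checkpoint
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    copy_var_from_ckpt(session, "v1:0", dst_v1,
                       tf.train.latest_checkpoint('./save_restore_model'),
                       "./save_restore_model/my_test.meta")
    print(session.run(dst_v1))  # expected: [1. 1. 1.]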
Once the graph has been built, it can also be saved on its own:
use write_graph/import_graph_def to save and load the raw graph definition,
or export_meta_graph/import_meta_graph to export and import a MetaGraph (the latter pair was demonstrated above).
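Here is a minimal sketch of the write_graph/import_graph_def route (paths and names are illustrative); note that it serializes only the graph structure, not variable values:
import tensorflow as tf
from google.protobuf import text_format

# Build a small graph and dump its GraphDef as a text protobuf
g = tf.Graph()
with g.as_default():
    a = tf.constant(3.0, name="a")
    b = tf.constant(4.0, name="b")
    c = tf.add(a, b, name="c")
tf.train.write_graph(g.as_graph_def(), './graph_dir', 'graph.pbtxt', as_text=True)

# Parse the text protobuf back into a GraphDef and import it into a new graph
with tf.gfile.GFile('./graph_dir/graph.pbtxt', 'r') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())
with tf.Graph().as_default() as new_g:
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        print(sess.run(new_g.get_tensor_by_name('c:0')))  # 7.0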
The code below is from "A quick complete tutorial to save and restore Tensorflow models" by ANKIT SACHAN.
import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1+w2)*b1

# Now, save the graph
saver.save(sess, './my_test_model/test', global_step=1000)
# Output
24.0
import tensorflow as tf

sess = tf.Session()
# First let's load the meta graph and restore the weights
saver = tf.train.import_meta_graph('./my_test_model/test-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./my_test_model'))

# Now, access the saved placeholders and create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore, feed_dict))
# This prints 60.0, computed from the new values of w1 and w2
# and the saved value of b1.
# Output
60.0
import tensorflow as tf

sess = tf.Session()
# First let's load the meta graph and restore the weights
saver = tf.train.import_meta_graph('./my_test_model/test-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./my_test_model'))

# Now, access the saved placeholders and create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# Add more ops to the restored graph
add_on_op = tf.multiply(op_to_restore, 2)
print(sess.run(add_on_op, feed_dict))
# This will print 120.0
See also: "Two ways to save and restore TensorFlow models", part 3, which walks through an MNIST-based example of saving and restoring a model with save/restore.
References
https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
https://zhuanlan.zhihu.com/p/31417693
https://maozezhong.github.io/2018/05/03/tensorflow/tensorflow保存模型及重载模型/
https://blog.csdn.net/u011106767/article/details/80052372
https://blog.csdn.net/m0_37924639/article/details/79340294