Saving and Restoring TensorFlow Models

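In TensorFlow, models are saved and loaded through the tf.train.Saver class. What gets saved, which checkpoints are retained, and how many are kept are all controlled by constructor arguments. The actual work is done by two methods of the class, save and restore.
Official guide: https://www.tensorflow.org/guide/saved_model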

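1. The tf.train.Saver class

The Saver constructor:

__init__(
    var_list=None, # variables to save or restore, as a list or a dict; defaults to all saveable variables
    reshape=False,
    sharded=False,
    max_to_keep=5, # maximum number of recent checkpoints to keep
    keep_checkpoint_every_n_hours=10000.0, # additionally keep one checkpoint for every N hours of training
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,  # write relative paths to the checkpoint state file
    filename=None # file name used for saving/loading, if known at graph construction time
)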

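For more detail on the parameters, see the official documentation: https://www.tensorflow.org/api_docs/python/tf/train/Saver

Create an instance before starting the session:

# With no arguments, all variables are saved
saver = tf.train.Saver()
# Pass the variables to save as a list or a dict
# saver = tf.train.Saver([v1])
# Limit how many recent checkpoints are kept
# saver = tf.train.Saver(max_to_keep=4)
# Additionally keep one checkpoint for every 2 hours of training
# saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)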
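
To make the effect of max_to_keep concrete, here is a minimal pure-Python sketch of the retention rule: only the most recent N checkpoints survive and older ones are dropped. It does not call TensorFlow. The function name simulate_max_to_keep and the checkpoint names are hypothetical, max_to_keep is assumed to be a positive integer, and the real Saver also deletes the files on disk and honors keep_checkpoint_every_n_hours.

```python
from collections import deque


def simulate_max_to_keep(checkpoint_names, max_to_keep=5):
    """Illustrate the retention rule; this is not TensorFlow's implementation.

    Returns (kept, dropped), both ordered from oldest to newest.
    """
    kept = deque()
    dropped = []
    for name in checkpoint_names:
        kept.append(name)
        # Once the window is full, the oldest checkpoint is discarded.
        if len(kept) > max_to_keep:
            dropped.append(kept.popleft())
    return list(kept), dropped


# Hypothetical run that saves every 1000 steps, six times in total.
saves = ["model-%d" % step for step in range(1000, 7000, 1000)]
kept, dropped = simulate_max_to_keep(saves, max_to_keep=4)
print(kept)     # ['model-3000', 'model-4000', 'model-5000', 'model-6000']
print(dropped)  # ['model-1000', 'model-2000']
```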

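2. The save() method

Calling save() on a Saver instance writes a checkpoint to the given path; an example follows.
The method takes the following parameters:

save(
    sess, # the session
    save_path, # path prefix for the checkpoint files, including the model name
    global_step=None, # if given, the step number is appended to save_path (e.g. my_test-1000)
    latest_filename=None, # name of the file that tracks the latest checkpoints; defaults to 'checkpoint'
    meta_graph_suffix='meta', # suffix of the MetaGraphDef file; defaults to 'meta'
    write_meta_graph=True, # whether to write the meta graph file
    write_state=True
)

The method returns the path prefix of the saved checkpoint.
Official example: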

import tensorflow as tf

# Create some variables.
v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# With no arguments, all variables are saved
saver = tf.train.Saver()
# Pass the variables to save as a list or a dict
# saver = tf.train.Saver([v1])
# Limit how many recent checkpoints are kept
# saver = tf.train.Saver(max_to_keep=4)
# Additionally keep one checkpoint for every 2 hours of training
# saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)

# Launch the model, initialize the variables, do some work, and save the variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Print the values
  print(v1.eval())
  print(v2.eval())
  # Save the variables to disk.
  saver.save(sess, "./save_restore_model/my_test")  # saved in ./save_restore_model, with files named 'my_test.XXX'
  # Passing global_step appends the step number to the file name (here: test_save_restore_model-1000)
  # saver.save(sess, "test_save_restore_model", global_step=1000, write_meta_graph=False)

# Output
[1. 1. 1.]
[-1. -1. -1. -1. -1.]

Saved result
Four files are generated: checkpoint, my_test.data-00000-of-00001, my_test.index, and my_test.meta.
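
The naming pattern of those four files can be sketched in plain Python. This is only an illustration under stated assumptions: a single data shard, the default 'checkpoint' state file name and 'meta' suffix, and the hypothetical helper name expected_checkpoint_files. It does not call TensorFlow or touch the disk.

```python
import posixpath


def expected_checkpoint_files(save_path, global_step=None):
    """List the file names a single-shard save is expected to produce.

    Illustrative only: assumes default latest_filename and meta_graph_suffix.
    """
    # With global_step, the step number is appended to the prefix.
    prefix = save_path if global_step is None else "%s-%d" % (save_path, global_step)
    directory = posixpath.dirname(prefix)
    return [
        posixpath.join(directory, "checkpoint"),  # tracks the latest checkpoints
        prefix + ".data-00000-of-00001",          # variable values
        prefix + ".index",                        # index into the data file
        prefix + ".meta",                         # the saved graph (MetaGraphDef)
    ]


for name in expected_checkpoint_files("./save_restore_model/my_test"):
    print(name)
```

Calling it with global_step=1000 gives names such as ./save_restore_model/my_test-1000.meta, which is why a model saved with a global step has to be loaded using the suffixed file name.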

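3. The restore() method

The Saver's restore() method restores variable values directly from the saved files.
restore takes only two arguments: the session in which to restore the variables, and the path the model was saved to.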

restore(
    sess,
    save_path
)

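Before restoring with a Saver built this way, we need to recreate variables with the same names and shapes as the ones that were saved. The example above has two variables, named v1 and v2:

v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)

When restoring, we must create the same two variables:

v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

Official example: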

import tensorflow as tf
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# Restore all variables
saver = tf.train.Saver()
# Restore only v2
# saver = tf.train.Saver({"v2": v2})

# Use the saver object normally after that.
with tf.Session() as sess:
  # Only needed when restoring a subset (e.g. only v2): the saver will not initialize v1.
  v1.initializer.run()
  saver.restore(sess, "./save_restore_model/my_test")

  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
# Output
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]

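In practice, when we want to use someone else's model and the model is large, recreating every variable by hand becomes impractical. However, the graph was already stored in the .meta file when the model was saved, so we can use tf.train.import_meta_graph to rebuild the graph directly from that file.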

tf.train.import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)

Example:

import tensorflow as tf
tf.reset_default_graph()

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./save_restore_model/my_test.meta')
    new_saver.restore(sess, "./save_restore_model/my_test")
    print(sess.run('v1:0'))
# Output
[1. 1. 1.]

The tf.train.latest_checkpoint() function finds the most recently saved checkpoint.

tf.train.latest_checkpoint(
    checkpoint_dir, # directory where the checkpoints were saved
    latest_filename=None # name of the file that tracks the latest checkpoints; defaults to 'checkpoint'
)

Example:

import tensorflow as tf
tf.reset_default_graph()

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./save_restore_model/my_test.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./save_restore_model'))
    print(sess.run('v1:0'))
# Output
[1. 1. 1.]
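
tf.train.latest_checkpoint() gets its answer from the small text file named checkpoint in the save directory. The sketch below parses such a file in plain Python to show the idea. The sample contents, the helper name latest_from_state_text, and the regular expression are illustrative assumptions; the real function parses the file as a protocol buffer and performs additional checks.

```python
import posixpath
import re

# Illustrative contents of a 'checkpoint' state file (assumed, not read from disk).
SAMPLE_STATE = '''model_checkpoint_path: "my_test"
all_model_checkpoint_paths: "my_test"
'''


def latest_from_state_text(state_text, checkpoint_dir):
    """Return the latest checkpoint prefix recorded in the state text, or None."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"',
                      state_text, flags=re.MULTILINE)
    if match is None:
        return None
    path = match.group(1)
    # Relative entries are interpreted relative to the checkpoint directory.
    if not posixpath.isabs(path):
        path = posixpath.join(checkpoint_dir, path)
    return path


print(latest_from_state_text(SAMPLE_STATE, "./save_restore_model"))
# ./save_restore_model/my_test
```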

Instead of passing the name 'v1:0' to sess.run(), a saved variable can also be fetched with graph.get_tensor_by_name(); see the example in section 5.
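
Names such as 'v1:0' follow the pattern op_name:output_index, where the index selects one output of the operation. The helper below is a hypothetical pure-Python illustration of that convention and is not part of TensorFlow; it only shows why get_tensor_by_name() needs the ':0' suffix while the variable itself was created with the name 'v1'.

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into (op_name, index); illustrative helper only."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError("expected '<op_name>:<output_index>', got %r" % tensor_name)
    return op_name, int(index)


print(split_tensor_name("v1:0"))             # ('v1', 0)
print(split_tensor_name("op_to_restore:0"))  # ('op_to_restore', 0)
```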

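3. Inspecting the variables saved in a checkpoint

Above we could print v1 directly because we knew that a variable named v1 had been saved. When using someone else's model, how do we find out the variable names?

Method 1:

Use the inspect_checkpoint library to quickly inspect the variables in a checkpoint.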

# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("./save_restore_model/my_test", tensor_name='', all_tensors=True)

# tensor_name:  v1
# [ 1.  1.  1.]
# tensor_name:  v2
# [-1. -1. -1. -1. -1.]

# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./save_restore_model/my_test", tensor_name='v1', all_tensors=False)

# tensor_name:  v1
# [ 1.  1.  1.]

# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./save_restore_model/my_test", tensor_name='v2', all_tensors=False)

# tensor_name:  v2
# [-1. -1. -1. -1. -1.]
# Output
tensor_name:  v1
[1. 1. 1.]
tensor_name:  v2
[-1. -1. -1. -1. -1.]
tensor_name:  v1
[1. 1. 1.]
tensor_name:  v2
[-1. -1. -1. -1. -1.]

Method 2

Adapted from muqiusangyang: https://my.oschina.net/u/3800567/blog/1637800

import tensorflow as tf
# Get the names of the trainable variables stored in a checkpoint
def get_trainable_variables_name_from_ckpt(meta_graph_path, ckpt_path):
    # Create a new graph
    graph = tf.Graph()
    # Make it the default graph
    with graph.as_default():
        with tf.Session() as session:
            # Load the computation graph from the .meta file into the default graph
            saver = tf.train.import_meta_graph(meta_graph_path)
            # Restore the variable values into the session
            saver.restore(session, ckpt_path)
            # tf.trainable_variables() only returns variables that already exist
            # in the graph when it is called; variables defined later are not included.
            # Collect the names rather than the Variable objects.
            v_names = [v.name for v in tf.trainable_variables()]
            return v_names

if __name__ == '__main__':
    meta_graph_path = "./save_restore_model/my_test.meta"
    ckpt_path = tf.train.latest_checkpoint('./save_restore_model')
    v_names = get_trainable_variables_name_from_ckpt(meta_graph_path, ckpt_path)
    print(v_names)
# Output
['v1:0', 'v2:0']

Method 3

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Use pywrap_tensorflow to read all variables in a checkpoint file;
# the result maps each variable name to its shape
def get_all_variables_name_from_ckpt(ckpt_path):
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    all_var = reader.get_variable_to_shape_map()
    # reader.get_variable_to_dtype_map() returns the dtypes instead
    return all_var

if __name__ == '__main__':
    ckpt_path = tf.train.latest_checkpoint('./save_restore_model')
    all_var = get_all_variables_name_from_ckpt(ckpt_path)
    print(all_var)

# Output
{'v1': [3], 'v2': [5]}
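
A name-to-shape map like the one printed above is enough to estimate the size of a model without loading any weights. The following sketch counts parameters per variable in plain Python; count_parameters is a hypothetical helper name, and the input dictionary is simply the output shown above.

```python
from functools import reduce
import operator


def count_parameters(var_to_shape):
    """Return (per-variable element counts, total) for a name -> shape map."""
    # The product over an empty shape is 1, so scalars count as one parameter.
    counts = {name: reduce(operator.mul, shape, 1)
              for name, shape in var_to_shape.items()}
    return counts, sum(counts.values())


counts, total = count_parameters({'v1': [3], 'v2': [5]})
print(counts)  # {'v1': 3, 'v2': 5}
print(total)   # 8
```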

Copying model parameters from a ckpt file into a user-defined variable

def copy_var_from_ckpt(session, dst_var_name, dst_var, ckpt_path, meta_graph_path):
    # dst_var_name must be a full tensor name such as 'v1:0'
    # Create a new graph
    graph = tf.Graph()
    # Make it the default graph
    with graph.as_default():
        with tf.Session() as sess:
            # Load the computation graph from the .meta file into the default graph
            saver = tf.train.import_meta_graph(meta_graph_path)
            # Restore the variable values into sess
            saver.restore(sess, ckpt_path)
            # tf.trainable_variables() only returns variables that already exist
            # in the graph when it is called. Collect their names (not the Variable
            # objects) so that the membership test below compares strings with strings.
            v_names = [v.name for v in tf.trainable_variables()]
            if dst_var_name in v_names:
                # Get the tensor
                tensor = graph.get_tensor_by_name(dst_var_name)
                # Evaluate it to obtain the stored weights
                weight = sess.run(tensor)
                # Copy the weights; this must run in the session that owns dst_var.
                # dst_var is a variable and weight is an array, so use an assign op.
                session.run(dst_var.assign(weight))

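4. Saving and loading Graph and MetaGraph

Once a graph has been built, it can be saved directly.
Use tf.train.write_graph and tf.import_graph_def to save and load it.
There are also tf.train.export_meta_graph and tf.train.import_meta_graph for exporting and importing a MetaGraph.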

5. Using a stored model for further work

The code is from: A quick complete tutorial to save and restore Tensorflow models, by ANKIT SACHAN

Create a new model and save it

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print(sess.run(w4,feed_dict))
#Prints 24, which is (w1+w2)*b1

#Now, save the graph
saver.save(sess, './my_test_model/test',global_step=1000)
# Output
24.0

Feed only new inputs and run the original model with its saved parameters:


import tensorflow as tf

sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./my_test_model/test-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./my_test_model'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore,feed_dict))
#This will print 60 which is calculated
#using new values of w1 and w2 and saved value of b1.
# Output
60.0

Adding new operations on top of the existing model


import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./my_test_model/test-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./my_test_model'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print(sess.run(add_on_op,feed_dict))
#This will print 120.
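
The three printed values in this section (24, 60 and 120) can be checked by hand, since the restored op computes (w1 + w2) * b1 with the saved b1 = 2.0. The snippet below mirrors that arithmetic in plain Python, without TensorFlow; the function name restored_op is hypothetical.

```python
def restored_op(w1, w2, b1=2.0):
    """Plain-Python mirror of op_to_restore: (w1 + w2) * b1."""
    return (w1 + w2) * b1


print(restored_op(4, 8))            # 24.0, the value at save time
print(restored_op(13.0, 17.0))      # 60.0, new inputs with the restored bias
print(restored_op(13.0, 17.0) * 2)  # 120.0, with the added multiply-by-2 op
```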

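Fine-tuning a pretrained model on new data

See Part 3 of the article "tensorflow保存和恢复模型的两种方法介绍", which builds an MNIST example showing how to use the save/restore methods to save and restore a model.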


References
https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
https://zhuanlan.zhihu.com/p/31417693
https://maozezhong.github.io/2018/05/03/tensorflow/tensorflow保存模型及重载模型/
https://blog.csdn.net/u011106767/article/details/80052372
https://blog.csdn.net/m0_37924639/article/details/79340294
