tensorflow 一般的模型保存和读取方法

一、基于.ckpt文件的模型保存和载入方法

1.保存模型

使用saver = tf.train.Saver() 定义一个存储器对象, 然后使saver.save() 函数保存模型. saver 定义时可以指定需要保存的变量列表, 最大的检查点数量, 是否保存计算图等.
比如:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# 使用字典指定要保存的变量, 此时可以为每个变量重命名(保存的名字)
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# 使用列表指定要保存的变量, 变量名字不变. 以下两种保存方式等价
saver = tf.train.Saver([v1, v2])
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

# 保存相应变量到指定文件, 如果指定保存训练到哪一步的参数,指定global_step参数就可以了, 则实际保存的名称变为 model.ckpt-xxxx
saver.save(sess, "./model.ckpt", global_step)

如果你想简单使用的话,就下面这样用就可以了:


import tensorflow as tf  
import numpy as np  

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')  
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')  

init = tf.initialize_all_variables()  
saver = tf.train.Saver()  
with tf.Session() as sess:  
        sess.run(init)  
        save_path = saver.save(sess,"save/model.ckpt")  

每保存一次就会产生这四个文件
checkpoint(更新)
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta(计算图文件放在这里)

2.读取模型

(1)读取参数

只读取参数时,直接使用使用saver.restore()方法载入,它使用的前提是,已经确定好了模型的结构,可以用它来代替初始化。

# 首先定义一系列变量
...
# 载入变量的值
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "path/to/model.ckpt")

注意模型路径中应当以诸如 .ckpt 之类的来结尾, 即需要保证实际存在的文件是 model.ckpt.data-00000-of-00001 和 model.ckpt.index , 而指定的路径是 model.ckpt 即可.

如果想读取最近一次检查点的模型的参数,则:

ckpt = tf.train.get_checkpoint_state(log_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

如果要查看里面的变量名或者变量的值,则:

from tensorflow.python import pywrap_tensorflow as pt
reader = pt.NewCheckpointReader("path/to/model.ckpt")
# 获取 变量名: 形状
vars = reader.get_variable_to_shape_map()
for k in sorted(vars):
    print(k, vars[k])

# 获取 变量名: 类型
vars = reader.get_variable_to_dtype_map()
for k in sorted(vars):
    print(k, vars[k])

# 获取张量的值
value = reader.get_tensor("tensor_name")

如果只想简单的用一下,这样就可以了:

import tensorflow as tf  
import numpy as np  

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')  
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')  

saver = tf.train.Saver()  
with tf.Session() as sess:  
        saver.restore(sess,"save/model.ckpt")  

(2)读取模型图

如果想不重新定义网络结构,利用tf.train.import_meta_graph将文件中所有的graph的所有节点保存到当前的default graph即可。

with tf.Session() as sess:
	new_saver = tf.train.import_meta_graph("model.ckpt.meta")

此时计算图就会加载到 sess 的默认计算图中, 这样我们就无需再次使用大量的脚本来定义计算图了. 实际上使用上面这两行代码即可完成计算图的读取. 注意可能我们获取的模型(meta文件)同时包含定义在CPU主机(host)和GPU等设备(device)上的, 上面的代码保留了原始的设备信息. 此时如果我们想同时加载模型权重, 那么如果当前没有指定设备的话就会出现错误, 因为tensorflow无法按照模型中的定义把某些变量(的值)放在指定的设备上. 那么有一个办法是增加一个参数清楚设备信息.

with tf.Session() as sess:
	new_saver = tf.train.import_meta_graph("model.ckpt.meta", clear_devices=True)

如果要获取计算图内的具体变量和操作,可以使用get_all_collection_keys()

sess.graph.get_all_collection_keys()
# 或
sess.graph.collections
# 或
tf.get_default_graph().get_all_collection_keys()
#或
tf.get_default_graph().et_operation_by_name().outputs[0]
# 输出
['summaries', 'train_op', 'trainable_variables', 'variables']

进一步通过get_collection()函数获取每个容器的内容

from pprint import pprint
pprint(sess.graph.get_collection("summaries"))
pprint(sess.graph.get_collection("variables"))
...

通过浏览 variables , trainable_variables , sumamries 和 train_op 中的变量我们可以初步推断计算图的结构和重要信息. 此外, 读取计算图后还可以直接使用 tf.summary.FileWriter() 保存计算图到 tensorboard, 从而获得更直观的计算图.
如果想简单点使用,就这样就可以了 :
保存模型:

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # 保存checkpoint, 同时也默认导出一个meta_graph
        # graph名为'my-model-{global_step}.meta'.
        saver.save(sess, 'my-model', global_step=step)

载入:

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
  y = tf.get_collection('pred_network')[0]

  graph = tf.get_default_graph()

  # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
  input_x = graph.get_operation_by_name('input_x').outputs[0]
  keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

  # 使用y进行预测  
  sess.run(y, feed_dict={input_x:....,  keep_prob:1.0})

注意两点:
一、 saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
在import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取

二、基于.pb文件的模型保存和载入方法

谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。
它的主要使用场景是实现创建模型与使用模型的解耦, 使得前向推导 inference的代码统一。
另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小,适合在手机端运行。
还有一个就是,真正离线测试使用的时候,pb格式的数据能够保证数据不会更新变动,就是不会进行反馈调节

1.保存模型

import tensorflow as tf

x = tf.placeholder(tf.float32,name="input")

a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
c = tf.Variable(tf.constant(10.,shape=[1]),name="c")
d = tf.Variable(tf.constant(2.,shape=[1]),name="d")

tensor1 = tf.multiply(a,b,"mul")
tensor2 = tf.subtract(tensor1,c,"sub")
tensor3 = tf.div(tensor2,d,"div")
result = tf.add(tensor3,x,"add")

inial = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(inial)
    print(sess.run(a))
    print(result)
    result = sess.run(result,feed_dict={x:1.0})
    print(result)
    constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
    with tf.gfile.FastGFile("model.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())

保存主要靠这几步:

constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
with tf.gfile.FastGFile("model.pb", mode='wb') as f:
    f.write(constant_graph.SerializeToString())

第一行代码的作用是将计算图中的变量转化为常量,并指定输出节点为“add”
第二行代码用来生成一个名为model.pb的文件(未指定路径的话,默认在该python代码的同路径下生成)
第三行代码的作用是将计算图写入该pb文件中

2.读取模型

with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
.........

参考链接:
https://www.cnblogs.com/flightless/p/10800476.html
https://blog.csdn.net/thriving_fcl/article/details/71423039?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task
https://www.jianshu.com/p/e5e36ffde809

你可能感兴趣的:(tensorflow 一般的模型保存和读取方法)