Tensorflow(四)保存和恢复模型,从原理到代码

1.TensorFlow模型文件简介

1.1 两类模型文件

TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络型。但是,Saver方法已经发生了更改,现在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括号里加入该参数可继续使用V1,但会报warning,可忽略,这样的结果是会生成cpkt文件。若使用saver = tf.train.Saver()则默认使用当前的版本(V2),保存后在save这个文件夹中会出现4个文件。
图中fine_tune文件夹下生产的四个模型文件是新版本的,而vgg_16.ckpt则是旧版本的函数生成的,目前Github中有一部分模型文件就是只有一个文件。


image.png

几个文件的信息,仅供参考:
checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。

2.Saver保存与恢复

2.1 Saver保存

import tensorflow as tf
import numpy as np
 
#定义W和b
W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight')
b = tf.Variable([1,2,3],dtype = tf.float32,name = 'biases')
#注:初始化变量Variable
init = tf.global_variables_initializer()
 
 
#建立tf.train.Saver() 来保存, 提取变量。
#建立my_net文件夹,保存变量
saver =  tf.train.Saver()
 
sess = tf.Session()
sess.run(init)
#保存变量到路径my_net
save_path = saver.save(sess,"my_net/save_net.ckpt")#保存格式为ckpt
 
#输出保存的变量
print("save path:",save_path)

2.2 Saver读取

import tensorflow as tf
import numpy as np
 
 
#建立W,b的空容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
 
#不需要初始化变量
 
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

3.【Tensorflow】tf.Graph()函数

参考资料

[1] 【Tensorflow】tf.Graph()函数
[2] TensorFlow Python API解析:图的核心数据结构

[3] Tensorflow框架实现中的“三”种图 非常重要,非常清晰
[4] Tensorflow模型持久化与恢复非常重要,非常清晰
[5] TensorFlow-网络模型的保存和读取
[6] tensorflow对自己的数据进行训练(选择性的恢复权值)(26)---《深度学习》 非常好,参考
[7] tensorflow 1.0 学习:模型的保存与恢复(Saver)
[8] TensorFlow中tf.train.Saver类说明

你可能感兴趣的:(Tensorflow(四)保存和恢复模型,从原理到代码)