【TensorFlow】freeze_graph

0、网络模型的保存和读取

TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。

(1)下面代码给出了保存TensorFlow模型的方法:

import tensorflow as tf

# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件
    print("Model saved in file:", saver_path)

注:Saver方法已经发生了更改,现在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括号里加入该参数可继续使用V1,但会报warning,可忽略。若使用saver = tf.train.Saver()则默认使用当前的版本(V2),保存后在save这个文件夹中会出现4个文件,比V1版多出model.ckpt.data-00000-of-00001这个文件,这点感谢评论里那位朋友指出。至于这个文件的含义到目前我仍不是很清楚,也没查到具体资料,TensorFlow15年底开源到现在很多类啊函数都一直发生着变动,或被更新或被弃用,可能一些代码在当时是没问题的,但过了一大段时间后再跑可能就会报错,在此注明事件时间:2017.4.30


这段代码中,通过saver.save函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

(2)TensorFlow模型的保存

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

            (.meta文件保存了当前图结构;.index文件保存了当前参数名;.data文件保存了当前参数值。)

【TensorFlow】freeze_graph_第1张图片

  • checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

  • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
    TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

  • model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。

(3)下面代码给出了加载TensorFlow模型的方法:

可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

import tensorflow as tf

# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
    print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
    print("v2:", sess.run(v2))
    print("Model Restored")

运行结果:

v1: [[ 0.76705766  1.82217288]]
v2: [[-0.98012197  1.2369734   0.5797025 ]
 [ 2.50458145  0.81897354  0.07858191]]
Model Restored

这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。


一、使用freeze_graph.py将模型文件和权重数据整合在一起并去除无关的Op

tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。

官方解释可参考:https://www.tensorflow.org/extend/tool_developers/#freezing


tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。

(1)freeze_graph.py是怎么做的呢?

        (a)它先加载模型文件

        (b)从checkpoint文件读取权重数据初始化到模型里的权重变量

        (c)将权重变量转换成权重常量 (因为常量能随模型一起保存在同一个文件里),

        (d)再通过指定的输出节点没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)

(2)路径

文件目录:tensorflow/python/tools/free_graph.py

测试文件:tensorflow/python/tools/free_graph_test.py 这个测试文件很有学习价值

(3)参数

    总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

    1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

    2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

    3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

    4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

    5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

    6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

    7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

    8、output_graph:(必选)用来保存整合后的模型输出文件。

    9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

    10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

    11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

(4)用法

    (a) 例:python tensorflow/python/tools/free_graph.py \
            --input_graph=some_graph_def.pb \
注意:这里的pb文件是用tf.train.write_graph方法保存的
         --input_checkpoint=model.ckpt.1001 \ 注意:这里若是r12以上的版本,只需给.data-00000....前面的文件名,   如:model.ckpt.1001.data-00000-of-00001,只需写model.ckpt.1001  
           --output_graph=/tmp/frozen_graph.pb

           --output_node_names=softmax

    (b)另外,如果模型文件是.meta格式的,也就是说用saver.Save方法和checkpoint一起生成的元模型文件,free_graph.py不适用,但可以改造下:

                1、copy free_graph.py为free_graph_meta.py

                2、修改free_graph.py,导入meta_graph:from tensorflow.python.framework import meta_graph

                3、将91行到97行换成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def

这样改即可加载meta文件

二、将变量(偏执,权重等)固化到模型数据

        在tensorflow中,graph是训练的核心,当一个模型训练完成后,需要将模型保存下来,一个通常的操作是:

variables = tf.all_variables()
                saver = tf.train.Saver(variables)
                saver.save(sess, "data/data.ckpt")
tf.train.write_graph(sess.graph_def, 'graph', 'model.ph', False)

    将model保存在model.ph文件中

    然而使用的时候不仅要加载模型文件model.ph,还要加载保存的data.ckpt数据文件才能使用。这样保持了数据与模型的分离。

当我们把一个训练模型完整的训练好上线时候,我们期待的场景是:将一张图片喂进去,然后得出结果。 这时候再这样加载或许有些不必要,特别是在一些变量”不明”的时候特别麻烦.这时候一个比较好的方法就是将变量(偏执,权重等)固化到模型数据中。

(1)创建图

【TensorFlow】freeze_graph_第2张图片

(2)声明tensor


(3)固化保存

【TensorFlow】freeze_graph_第3张图片

固化操作中最重要的函数是:

tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None)
代码运行后控制台打印


这样在我们使用的时候就不要再进行data.ckpt的数据恢复。直接通过:

sess.graph.get_tensor_by_name()
就可以获取一个tensor。




你可能感兴趣的:(TensorFlow)