Tensorflow详解(七)——模型持久化

 

一 ,使用Saver类会将模型保存为 .ckpt格式,这样会保存模型中的全部信息;PB文件适用于保存部分信息,且可以被其他语言和深度学习框架读取和继续训练,适用于迁移训练

train.Saver类 ——用于保存和还原一个神经网络模型的API

1.模型保存加载

1.1保存加载全部变量

     步骤1.定义Saver类对象  

         saver=tf.train.Saver()

             2.模型保存

                  saver.save(sess,'保存路径+文件名')

               说明:虽然指定了一个文件路径,但这个文件目录下会出现4个文件;(旧版本只会生成model.ckpt一个文件)

checkpoint 保存了一个目录下所有的模型文件列表
model.ckpt.data 保存变量的取值,二进制文件
model.ckpt.index 保存了每个变量的名称,二进制文件
model.ckpt.meta 保存了计算图的结构

            3.模型加载

               与保存模型的代码很相近,但省略了初始化全部变量的过程,使用restore()函数完成变量赋值过程;

               3.1定义图上所有的运算,变量名需要与模型存储的变量名一致

               3.2在会话中完成模型加载

                  saver.restore(sess,'保存路径+文件名')

1.2保存加载部分变量——方便于修改网络

     步骤1.定义Saver类对象 ,同时提供一个列表来指定需要保存或加载的变量 

         saver=tf.train.Saver([变量列表])

             2.模型保存

                  saver.save()

            3.模型加载

                  saver.restore()

1.3保存或加载时给变量重新命名——通常用于代码中需要加载的变量和模型中保存的变量有不同名称

     步骤1.定义Saver类对象 ,以字典的方式将模型保存时的变量名和需要加载的变量名联系起来 

         saver=tf.train.Saver({键:值})

             2.模型保存

                  saver.save()

            3.模型加载

                  saver.restore()

 2.计算图的加载              

     步骤1.通过 .meta文件直接加载持久化的计算图

         meta_graph=tf.train.import_meta_graph() #返回一个Saver类

            2.模型加载

                  meta_graph.restore()

           3.获取默认计算图上指定节点处的张量

                 tf.get_default_graph().get_tensor_by_name()  

   3.读取变量取值

      说明:读取持久化的变量的取值时,.data/.index两个文件缺一不可

      步骤1.读取文件

         reader=tf.train.NewCheckpointReader('路径+文件名.ckpt')
       #其他读取方式
       from tensorflow.python import pywrap_tensorflow
       reader = pywrap_tensorflow.NewCheckpointReader(file_name)

                   2.对象的常用方法

        1.variable=reader.get_tensor('变量名')               #获取单个变量的值

        2.all_variables=reader.get_variable_to_shape_map()  #按字典格式返回 键:变量名,值:变量形状

        3.all_variables=reader.get_variable_to_dtype_map()  #按字典格式返回 键:变量名,值:变量类型

二,PB文件——模型的变量都会变成固定的常量,保证模型会被大大的减少,适用于一些类似手机的移动端运行

2.1模型保存

      步骤1.得到计算图中的节点信息GraphDef,可以通过as_graph_def()函数完成

                 graph_def = tf.get_default_graph().as_graph_def()

             2.1导入graph_util模块

         from tensorflow.python.framework import graph_util

              2.2将计算图的变量以及取值通过常量方式保存

         graph_util.convert_variables_to_constants()

              函数解析:convert_variables_to_constants()

              3.将导出的模型存入.pb文件

                  with tf.gfile.GFile("/home/jiangziyang/model/model.pb", "wb") as f:      

                  # SerializeToString()函数用于将获取到的数据取出存到一个string对象中,
                  # 然后再以二进制流的方式将其写入到磁盘文件中
                           f.write(output_graph_def.SerializeToString())    

2.2模型获取

     步骤1.导入gfile模块   from tensorflow.python.platform import gfile

            2.使用FsatGFile类的构造函数返回一个FastGFile类

               with gfile.FastGFile("/home/jiangziyang/model/model.pb", 'rb') as f:

            3.读取并解析

                graph_def = tf.GraphDef()
                # 使用FastGFile类的read()函数读取保存的模型文件,并以字符串形式
                # 返回文件的内容,之后通过ParseFromString()函数解析文件的内容
                graph_def.ParseFromString(f.read())

             4.使用import_graph_def()函数将graph_def中保存的计算图加载到当前图中

               result = tf.import_graph_def()

 

三,函数解析

    1. convert_variables_to_constants()函数表示用相同值的常量替换计算图中所有变量

       原型:convert_variables_to_constants(sess,input_graph_def,output_node_names,
                                                                          variable_names_whitelist, variable_names_blacklist)

      参数:sess:会话

                input_graph_def:具有节点的GraphDef对象,

                output_node_names:要保存的计算图中的计算节点的名称,通常为字符串列表的形式

                variable_names_whitelist:要转换为常量的变量名称集合(默认情况下,所有变量都将被转换),
                variable_names_blacklist:要省略转换为常量的变量名的集合

     2. GraphDef() 计算图中的节点信息  

     3. import_graph_def()函数将graph_def中保存的计算图加载到当前图中

         原型:import_graph_def(graph_def,input_map,return_elements,name,op_dict,producer_op_list)

         参数:graph_def:传递进来的GraphDef

                    return_elements: 指定要将graph_def中的哪个节点作为函数返回的结果

               
    
   
   

  

                 

 

 

     

 

 
 

你可能感兴趣的:(tensorflow)