一 ,使用Saver类会将模型保存为 .ckpt格式,这样会保存模型中的全部信息;PB文件适用于保存部分信息,且可以被其他语言和深度学习框架读取和继续训练,适用于迁移训练
train.Saver类 ——用于保存和还原一个神经网络模型的API
1.模型保存加载
1.1保存加载全部变量
步骤1.定义Saver类对象
saver=tf.train.Saver()
2.模型保存
saver.save(sess,'保存路径+文件名')
说明:虽然指定了一个文件路径,但这个文件目录下会出现4个文件;(旧版本只会生成model.ckpt一个文件)
checkpoint | 保存了一个目录下所有的模型文件列表 |
model.ckpt.data | 保存变量的取值,二进制文件 |
model.ckpt.index | 保存了每个变量的名称,二进制文件 |
model.ckpt.meta | 保存了计算图的结构 |
3.模型加载
与保存模型的代码很相近,但省略了初始化全部变量的过程,使用restore()函数完成变量赋值过程;
3.1定义图上所有的运算,变量名需要与模型存储的变量名一致
3.2在会话中完成模型加载
saver.restore(sess,'保存路径+文件名')
1.2保存加载部分变量——方便于修改网络
步骤1.定义Saver类对象 ,同时提供一个列表来指定需要保存或加载的变量
saver=tf.train.Saver([变量列表])
2.模型保存
saver.save()
3.模型加载
saver.restore()
1.3保存或加载时给变量重新命名——通常用于代码中需要加载的变量和模型中保存的变量有不同名称
步骤1.定义Saver类对象 ,以字典的方式将模型保存时的变量名和需要加载的变量名联系起来
saver=tf.train.Saver({键:值})
2.模型保存
saver.save()
3.模型加载
saver.restore()
2.计算图的加载
步骤1.通过 .meta文件直接加载持久化的计算图
meta_graph=tf.train.import_meta_graph() #返回一个Saver类
2.模型加载
meta_graph.restore()
3.获取默认计算图上指定节点处的张量
tf.get_default_graph().get_tensor_by_name()
3.读取变量取值
说明:读取持久化的变量的取值时,.data/.index两个文件缺一不可
步骤1.读取文件
reader=tf.train.NewCheckpointReader('路径+文件名.ckpt')
#其他读取方式
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
2.对象的常用方法
1.variable=reader.get_tensor('变量名') #获取单个变量的值
2.all_variables=reader.get_variable_to_shape_map() #按字典格式返回 键:变量名,值:变量形状
3.all_variables=reader.get_variable_to_dtype_map() #按字典格式返回 键:变量名,值:变量类型
二,PB文件——模型的变量都会变成固定的常量,保证模型会被大大的减少,适用于一些类似手机的移动端运行
2.1模型保存
步骤1.得到计算图中的节点信息GraphDef,可以通过as_graph_def()函数完成
graph_def = tf.get_default_graph().as_graph_def()
2.1导入graph_util模块
from tensorflow.python.framework import graph_util
2.2将计算图的变量以及取值通过常量方式保存
graph_util.convert_variables_to_constants()
函数解析:convert_variables_to_constants()
3.将导出的模型存入.pb文件
with tf.gfile.GFile("/home/jiangziyang/model/model.pb", "wb") as f:
# SerializeToString()函数用于将获取到的数据取出存到一个string对象中,
# 然后再以二进制流的方式将其写入到磁盘文件中
f.write(output_graph_def.SerializeToString())
2.2模型获取
步骤1.导入gfile模块 from tensorflow.python.platform import gfile
2.使用FsatGFile类的构造函数返回一个FastGFile类
with gfile.FastGFile("/home/jiangziyang/model/model.pb", 'rb') as f:
3.读取并解析
graph_def = tf.GraphDef()
# 使用FastGFile类的read()函数读取保存的模型文件,并以字符串形式
# 返回文件的内容,之后通过ParseFromString()函数解析文件的内容
graph_def.ParseFromString(f.read())
4.使用import_graph_def()函数将graph_def中保存的计算图加载到当前图中
result = tf.import_graph_def()
三,函数解析
1. convert_variables_to_constants()函数表示用相同值的常量替换计算图中所有变量
原型:convert_variables_to_constants(sess,input_graph_def,output_node_names,
variable_names_whitelist, variable_names_blacklist)
参数:sess:会话
input_graph_def:具有节点的GraphDef对象,
output_node_names:要保存的计算图中的计算节点的名称,通常为字符串列表的形式
variable_names_whitelist:要转换为常量的变量名称集合(默认情况下,所有变量都将被转换),
variable_names_blacklist:要省略转换为常量的变量名的集合
2. GraphDef() 计算图中的节点信息
3. import_graph_def()函数将graph_def中保存的计算图加载到当前图中
原型:import_graph_def(graph_def,input_map,return_elements,name,op_dict,producer_op_list)
参数:graph_def:传递进来的GraphDef
return_elements: 指定要将graph_def中的哪个节点作为函数返回的结果