tensorflow中模型保存加载操作的学习体会

姓名:乐仁华 学号:16140220023

【嵌牛导读】:本文简述了学习tensorflow中保存加载模型的体会及总结

【嵌牛鼻子】:tensorflow,检查点文件

【嵌牛提问】:tensorflow中保存加载模型有什么方法?

【嵌牛正文】:

先简单提一下模型参数保存及加载的函数

tf.train.Saver()

tf.train.Saver()是tensorflow中加载,保存模型参数的一个类
使用方法:

#实例化类
saver = tf.train.Saver()

#调用save方法保存参数,ckpt为保存的模型参数名,如'run_dir/model.ckpt',
#其中run_dir表示模型所在的文件夹
#step表示迭代步数
saver.save(sess,ckpt,gloabal_step=step)

#如果需要加载参数
restorer = tf.train.Saver()
#这里的ckpt与保存过程的ckpt一致
restorer.restore(sess,ckpt)

更多详细的用法可以看官方文档

检查点文件格式

tensorflow中模型保存加载操作的学习体会_第1张图片

保存的检查点文件如上图所示,
.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值
每调用一次save方法会产生新的文件

获取最新保存的检查点文件

#假设check_path为保存这些检查点文件的文件夹
#tf.train.get_checkpoint_state(check_point)表示查看check_point文件夹下是否有检查点文件
ckpt = tf.train.get_checkpoint_state(check_point)
#获取最新保存的模型检查点文件
ckpt.model_checkpoint_path

还有其他的方法,不过我没怎么用过,大家可以自己上网查查

查看检查点文件中的各tensor

有时我们会需要查看检查点文件中各变量,这时可以调用tensorflow中的方法查看

from tensorflow.python import pywrap_tensorflow

# 从检查点文件中读取数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# 显示变量名及其值
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

保存及加载图结构

我们知道tensorflow是以图表示计算过程的,各节点操作都在图上,自然也就有保存图结构的方法

tf.train.write_graph()

具体参数看这:

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)
# Writes a graph proto to a file.
#      graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
#      logdir: Directory where to write the graph. This can refer to remote
#        filesystems, such as Google Cloud Storage (GCS).
#      name: Filename for the graph.
#      as_text: If `True`, writes the graph as an ASCII proto.
    
#    Returns:
#      The path of the output proto file.
(从内置文档摘来的,相信大家都看得懂^_^)

如果要加载的话就用这个:

tf.train.import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None)
#参数如下
#meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
#       the path) containing a `MetaGraphDef`.
#    clear_devices: Whether or not to clear the device field for an `Operation`
#        or `Tensor` during import.
#     import_scope: Optional `string`. Name scope to add. Only used when
#        initializing from protocol buffer.
#      **kwargs: Optional keyed arguments.
    
#    Returns:
#      A saver constructed from `saver_def` in `MetaGraphDef` or None.
    
#      A None value is returned if no variables exist in the `MetaGraphDef`
(还是相信大家^_^)

细心的读者可能发现了前头提到的检查点文件里面也有个保存结构的文件,那这两者有啥区别吗,说实话我也不清楚。。。。。

你可能感兴趣的:(tensorflow中模型保存加载操作的学习体会)