tensorflow checkpoint 转换为 pb 文件

参考链接:https://www.jianshu.com/p/091415b114e2

https://cloud.tencent.com/developer/ask/188650

由于arm nn官方提供的mnist-tf例程中提供的模型类型是prototxt或者pb文件,所以这里需要把tensorflow保存的ckpt文件转换成pb文件

tensorflow训练生成的ckpt文件包含4个,分别是

1.    checkpoint文件,记录了最新的检查点文件

2.    model.data文件,是saver.save(sess)保存的结果,记录了所有变量的值

3.    model.index文件,暂不明确,待查。恢复模型不必须用到

4.    model.meta文件,保存了计算图的结构,没有变量的值

转换方法

  1. 使用freeze_graph(见第一个参考链接,经过测试发现对于很小的模型lenet5可以成功,但是对于较大的模型,比如这里用到的一个400MB左右的网络,经过测试,会把16GB的内存消耗干净,转换失败=_=)

  2. 使用convert_variables_to_constants

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

sess = tf.Session()
saver = tf.train.import_meta_graph("meta文件目录")
saver.restore(sess, tf.train.latest_checkpoint("checkpoint文件所在目录"))
graph = tf.get_default_graph()

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['输出tensor名字'])
with tf.gfile.FastGFile('pb文件保存目录', mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

你可能感兴趣的:(tensorflow checkpoint 转换为 pb 文件)