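[TensorFlow Series] [Part 3] Freezing a model file and running inference
This article is a very concise example that gets you started quickly.
TensorFlow: printing the tensor_name of nodes in .ckpt and .pb models, and converting a .ckpt model to a .pb model
Key points: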
import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))  # the corresponding value
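As a cross-check of the reader loop above, here is a minimal sketch that uses tf.train.list_variables, which returns (name, shape) pairs for a checkpoint. It assumes TensorFlow 1.15 or 2.x, so that tf.compat.v1 is available. The toy checkpoint with variables w and b is hypothetical and exists only so that the snippet runs on its own.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def save_toy_checkpoint(directory):
    """Save a hypothetical toy checkpoint holding two variables, w and b."""
    graph = tf1.Graph()
    with graph.as_default():
        w = tf1.get_variable("w", shape=[2, 3], initializer=tf1.zeros_initializer())
        b = tf1.get_variable("b", shape=[3], initializer=tf1.zeros_initializer())
        saver = tf1.train.Saver([w, b])
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # save() returns the checkpoint prefix, e.g. <directory>/model.ckpt
            return saver.save(sess, os.path.join(directory, "model.ckpt"))


def list_checkpoint_variables(ckpt_path):
    """Map every variable name stored in the checkpoint to its shape (a list of ints)."""
    return dict(tf.train.list_variables(ckpt_path))
```

For the checkpoint in the text, calling list_checkpoint_variables on the same prefix (./ade20k/model.ckpt-27150) should list the same names as the reader loop, along with their shapes.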
import tensorflow as tf
import os
model_dir = './'
model_name = 'model.pb'
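def create_graph():
    # Read the serialized GraphDef from the .pb file and import it into the default graph
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
create_graph()
# These are node (operation) names; append ":0" to get the name of a node's first output tensor
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name, '\n')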
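The list comprehension above collects node (operation) names, not tensor names. A tensor is named "<node name>:<output index>". The minimal sketch below reads a .pb file straight into a GraphDef and lists each node's name and op type without touching the default graph. It writes its own hypothetical toy .pb (node names input and output are assumptions) so that it runs standalone.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def write_toy_pb(path):
    """Write a small variable-free toy GraphDef (hypothetical) as a binary .pb file."""
    graph = tf1.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
        tf1.multiply(x, 2.0, name="output")
    with open(path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


def list_pb_nodes(pb_path):
    """Parse a binary .pb and return (node_name, op_type) pairs; no graph is built."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return [(node.name, node.op) for node in graph_def.node]
```

Printing node.op next to node.name makes the feed and fetch candidates (Placeholder nodes and the final ops) easier to spot than names alone.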
import tensorflow as tf
# Rebuild the graph structure from the .meta file; clear_devices drops device placements
saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)
# [Important!] Fill in the names of the output nodes here
output_nodes = ["xxx"]
with tf.Session(graph=tf.get_default_graph()) as sess:
    input_graph_def = sess.graph.as_graph_def()
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    # Replace every variable with a constant holding its restored value
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                    input_graph_def,
                                                                    output_nodes)
    with open("frozen_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())
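The freezing steps above can be run end to end on a toy model. This is a sketch, not the original author's code. The model y = x @ w + b and the node names input, output, w and b are hypothetical. The call to tf.compat.v1.disable_v2_behavior() is an assumption added so that the same TF1-style code also runs under TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_v2_behavior()  # assumption: TF1-style graphs and sessions, also on TF 2.x


def train_and_save(directory):
    """Build a toy model y = x @ w + b, save a checkpoint (plus .meta), return its prefix."""
    tf1.reset_default_graph()
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.get_variable("w", initializer=tf1.constant([[1.0], [2.0]]))
    b = tf1.get_variable("b", initializer=tf1.constant([0.5]))
    tf1.add(tf1.matmul(x, w), b, name="output")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        return saver.save(sess, os.path.join(directory, "model.ckpt"))


def freeze(ckpt_prefix, output_nodes, frozen_path):
    """Restore from the .meta file and bake the variable values into constants."""
    tf1.reset_default_graph()
    saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta", clear_devices=True)
    with tf1.Session() as sess:
        saver.restore(sess, ckpt_prefix)
        frozen = tf1.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), output_nodes)
    with open(frozen_path, "wb") as f:
        f.write(frozen.SerializeToString())
    return frozen
```

After freezing, w and b appear as Const nodes, and nodes that output does not depend on (such as the save/ ops) are pruned. The .pb file therefore no longer needs the checkpoint.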
TensorFlow: freezing a trained model (baking the weights into the graph) and using it for prediction
This one is more detailed.
Key points:
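# `graph` is the tf.Graph the frozen .pb was imported into, presumably with
# tf.import_graph_def(graph_def, name='prefix'), hence the "prefix/" on every name
for op in graph.get_operations():
    print(op.name, op.values())
# prefix/Placeholder/inputs_placeholder
# ...
# prefix/Accuracy/predictions
# The operations include prefix/Placeholder/inputs_placeholder and prefix/Accuracy/predictions.
# To predict, we need to find the tensors to feed, which means we need their names.
# Note: prefix/Placeholder/inputs_placeholder is only the operation name;
# prefix/Placeholder/inputs_placeholder:0 is the tensor name.
x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')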
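The sketch below makes the distinction between op names and tensor names concrete. It writes a variable-free toy graph whose op names mirror the ones above (Placeholder/inputs_placeholder and Accuracy/predictions), imports it with name='prefix', and feeds it. The toy graph y = 2 * x and the helper names are hypothetical stand-ins for a real frozen model.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def write_frozen_toy_pb(path):
    """Write a variable-free toy graph (y = 2 * x) standing in for a real frozen model."""
    graph = tf1.Graph()
    with graph.as_default():
        with tf1.name_scope("Placeholder"):
            x = tf1.placeholder(tf.float32, shape=[None, 2], name="inputs_placeholder")
        with tf1.name_scope("Accuracy"):
            tf1.multiply(x, 2.0, name="predictions")
    with open(path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


def load_graph(pb_path, prefix="prefix"):
    """Import a frozen GraphDef into a fresh graph; every op name gets '<prefix>/' prepended."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf1.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name=prefix)
    return graph


def predict(graph, batch, prefix="prefix"):
    # Tensor name = operation name + ":" + output index
    x = graph.get_tensor_by_name(prefix + "/Placeholder/inputs_placeholder:0")
    y = graph.get_tensor_by_name(prefix + "/Accuracy/predictions:0")
    with tf1.Session(graph=graph) as sess:
        return sess.run(y, feed_dict={x: batch})
```

Importing into a fresh tf.Graph rather than the default graph keeps the prefix predictable. It also avoids name clashes if several models are loaded in one process.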
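TensorFlow deep learning in practice, notes (2): freezing a trained model
Key points:
Uses export_inference_graph.py and freeze_graph.py to freeze the model.
TensorFlow: How to freeze a model and serve it with a python API
The article above is actually a summary of this one.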
Saving, Freezing, Optimizing for inference, Restoring of tensorflow models
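At first glance this article seems to include some redundant steps, but I haven't read it closely yet.
Key points:
Uses tf.train.write_graph and freeze_graph.freeze_graph to freeze the model.
optimize_for_inference_lib.optimize_for_inference
Purpose: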
tf.contrib.layers.flatten or tf.reshape
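Here is a minimal sketch of optimize_for_inference_lib.optimize_for_inference on a hypothetical toy graph. Among other rewrites, the function strips nodes that the outputs do not depend on and splices out Identity nodes. Both effects are easy to check by listing node names before and after. Its arguments are the GraphDef, the input node names, the output node names, and the dtype enum of the input placeholder. The node names used here (input, ident, output, unused) are assumptions for illustration.

```python
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

tf1 = tf.compat.v1


def build_toy_graph_def():
    """Toy graph with a removable Identity and a dangling branch (hypothetical)."""
    graph = tf1.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
        ident = tf1.identity(x, name="ident")  # Identity: expected to be spliced out
        tf1.multiply(ident, 2.0, name="output")
        tf1.add(x, 1.0, name="unused")  # "output" does not need this: expected to be stripped
    return graph.as_graph_def()


def optimize(graph_def):
    return optimize_for_inference_lib.optimize_for_inference(
        graph_def,
        ["input"],  # input node names
        ["output"],  # output node names
        tf.float32.as_datatype_enum)  # dtype of the input placeholder
```

In practice, the input is usually an already frozen GraphDef, and the optimized result is serialized to a new .pb file in the same way as the frozen graph.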
This one also uses tf.train.write_graph and freeze_graph.freeze_graph to freeze the model, but it is not as good as the article above.
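The tf.train.write_graph + freeze_graph.freeze_graph route can also be called from Python instead of the command line. Below is a sketch on a hypothetical toy model (TensorFlow 1.15, or 2.x via tf.compat.v1, is assumed). The restore_op_name and filename_tensor_name values are the conventional ones. They are deprecated and, according to the script's own flag help, unused by the current loading code, but the function still expects them.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.tools import freeze_graph

tf1 = tf.compat.v1
tf1.disable_v2_behavior()  # assumption: TF1-style graph mode, also on TF 2.x


def export_graph_and_checkpoint(directory):
    """Toy model: save a checkpoint, then write graph.pb (structure only) with write_graph."""
    tf1.reset_default_graph()
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.get_variable("w", initializer=tf1.constant([[1.0], [2.0]]))
    tf1.matmul(x, w, name="output")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        ckpt = saver.save(sess, os.path.join(directory, "model.ckpt"))
        graph_path = tf1.train.write_graph(sess.graph.as_graph_def(), directory,
                                           "graph.pb", as_text=False)
    return graph_path, ckpt


def freeze_with_script_api(graph_path, ckpt, frozen_path):
    # freeze_graph imports graph.pb into the default graph, so start from an empty one
    tf1.reset_default_graph()
    freeze_graph.freeze_graph(
        input_graph=graph_path,
        input_saver="",
        input_binary=True,
        input_checkpoint=ckpt,
        output_node_names="output",  # comma-separated string, as on the command line
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=frozen_path,
        clear_devices=True,
        initializer_nodes="")
```

The keyword arguments correspond one-to-one to the command-line flags used by the batch script at the end of these notes.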
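Viewing the node information of a TensorFlow model file
Viewing the node information of a TensorFlow .pb model file
Half of the first of these two articles comes from the second.
Another way to freeze a model: use the official script tensorflow/python/tools/freeze_graph.py
Steps: first write the graph structure to graph.pb with tf.train.write_graph (Python script below), then run freeze_graph.py on graph.pb together with the checkpoint (batch script at the end).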
""" This script is a mixture of
https://github.com/YunYang1994/tensorflow-yolov3/blob/master/freeze_graph.py
and
https://gist.github.com/domluna/ed477cb5698c787f29c7d56fba381fed
(redirected from https://github.com/tensorflow/tensorflow/issues/10663)
"""
import os
import tensorflow as tf
from core.yolov3 import YOLOV3
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# An empty string hides all GPUs, so the graph is built on the CPU
os.environ["CUDA_VISIBLE_DEVICES"] = str()
# pb_file and output_node_names are not used below; the node names are passed to
# freeze_graph.py via --output_node_names instead
pb_file = "./yolov3_coco.pb"
ckpt_file = "./checkpoint/yolov3_coco_demo.ckpt"
output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
with tf.name_scope('input'):
    input_data = tf.placeholder(dtype=tf.float32, name='input_data')
model = YOLOV3(input_data, trainable=False)
print(model.conv_sbbox, model.conv_mbbox, model.conv_lbbox)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver()
saver.restore(sess, ckpt_file)
# Write only the graph structure (no variable values); freeze_graph.py merges in the checkpoint
tf.train.write_graph(sess.graph.as_graph_def(), './freeze_model/', "graph.pb", as_text=False)
REM Batch script (Windows cmd; ^ continues the command on the next line)
python C:\Python27\Lib\site-packages\tensorflow\python\tools\freeze_graph.py ^
--input_graph=.\freeze_model\graph.pb ^
--output_graph=.\freeze_model\frozen.pb ^
--input_checkpoint=.\checkpoint\yolov3_coco_demo.ckpt ^
--output_node_names=input/input_data,pred_sbbox/concat_2,pred_mbbox/concat_2,pred_lbbox/concat_2 ^
--input_binary=true
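The batch script hard-codes the path to freeze_graph.py. An alternative, sketched below, assumes that the script can be run as a module (it has a __main__ entry point). The same argument list is built in Python, and the node list from the Python script is joined into the comma-separated string that --output_node_names expects. The helper name freeze_graph_command is hypothetical.

```python
import sys

# Same node names as output_node_names in the Python script above
output_node_names = ["input/input_data", "pred_sbbox/concat_2",
                     "pred_mbbox/concat_2", "pred_lbbox/concat_2"]


def freeze_graph_command(input_graph, output_graph, input_checkpoint, output_nodes):
    """Build the argv for running freeze_graph as a module (hypothetical helper)."""
    return [
        sys.executable, "-m", "tensorflow.python.tools.freeze_graph",
        "--input_graph=" + input_graph,
        "--output_graph=" + output_graph,
        "--input_checkpoint=" + input_checkpoint,
        # The flag expects one comma-separated string, not a Python list
        "--output_node_names=" + ",".join(output_nodes),
        "--input_binary=true",
    ]
```

To run it, pass the result to subprocess.run(..., check=True) with the paths from the batch script (./freeze_model/graph.pb, ./freeze_model/frozen.pb, ./checkpoint/yolov3_coco_demo.ckpt). Using sys.executable picks up whichever Python has TensorFlow installed, so no site-packages path needs to be hard-coded.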