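# 转换模型时，首先要知道模型从哪个（些）节点输出；如果没有源代码，很难弄清楚节点信息。
# (When converting a model, you first need to know which node(s) produce its output;
# without the source code it is hard to work out the node information.)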
import os
from tensorflow.python import pywrap_tensorflow
# List the name of every variable stored in the training checkpoint.
import tensorflow as tf

checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
# tf.train.load_checkpoint is the public way to obtain a CheckpointReader.
# pywrap_tensorflow.NewCheckpointReader is a private symbol that newer
# TensorFlow releases no longer expose from that module.
reader = tf.train.load_checkpoint(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # To also dump the stored value, call reader.get_tensor(key).
import tensorflow as tf
import os
# Directory and file name of the frozen GraphDef that create_graph() loads.
model_dir, model_name = './', 'model.pb'
def create_graph(graph_path=None):
    """Load a serialized (frozen) GraphDef into the current default graph.

    Nodes are imported with ``name=''`` so they keep their original names.

    Args:
        graph_path: Path of the ``.pb`` file to load. Defaults to
            ``os.path.join(model_dir, model_name)``, resolved at call time.

    Returns:
        The parsed ``GraphDef``, e.g. for inspecting its nodes.
    """
    if graph_path is None:
        graph_path = os.path.join(model_dir, model_name)
    # GFile instead of the deprecated FastGFile; same API the freezing code
    # uses to write the .pb file.
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    return graph_def
# Import the frozen graph, then print every node name it contains so the
# output node(s) can be identified.
create_graph()
graph_nodes = tf.get_default_graph().as_graph_def().node
tensor_name_list = [node_def.name for node_def in graph_nodes]
for node_name in tensor_name_list:
    print(node_name, '\n')
from tensorflow.python.tools import inspect_checkpoint as chkp
import tensorflow as tf
# Rebuild the training graph from its .meta file, dropping the device
# placements that were recorded at training time.
saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)

# IMPORTANT: fill in the real output node name(s) here.
output_nodes = ["xxx"]

with tf.Session(graph=tf.get_default_graph()) as sess:
    graph_def = sess.graph.as_graph_def()
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    # Replace the variables feeding output_nodes with constants that hold
    # their restored values.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph_def, output_nodes)
    serialized = frozen_graph_def.SerializeToString()
    with open("frozen_model.pb", "wb") as pb_file:
        pb_file.write(serialized)
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Placeholders: the checkpoint prefix to restore and the frozen .pb to write.
checkpoint, graph_file = "model.ckpt-xxx", "xxx.pb"
def return_ops(candidate):
    """Collect the ``.op`` of every tensor in *candidate*.

    Args:
        candidate: A single object exposing ``.op`` (e.g. a tensor), or an
            arbitrarily nested list/tuple of such objects.

    Returns:
        A new flat list of the ``.op`` attributes, in depth-first order.
    """
    if not isinstance(candidate, (list, tuple)):
        return [candidate.op]
    collected = []
    for item in candidate:
        collected.extend(return_ops(item))
    return collected
def dump_graph(checkpoint_prefix=None, output_path=None):
    """Rebuild the inference graph, restore its weights and write a frozen GraphDef.

    Also writes ``<output_path>.info``, a text summary of the input/output
    tensors produced by ``gen_info``.

    Args:
        checkpoint_prefix: Checkpoint to restore. Defaults to the module-level
            ``checkpoint``.
        output_path: Destination of the frozen ``.pb``. Defaults to the
            module-level ``graph_file``.
    """
    if checkpoint_prefix is None:
        checkpoint_prefix = checkpoint
    if output_path is None:
        output_path = graph_file

    with tf.Graph().as_default():
        inputs = setup_input(dtype=tf.float32,
                             shape=[None, 224, 224, 3],
                             name='graph_input')
        outputs = model_inference(inputs, 1000)
        model_info = gen_info(inputs, outputs)
        print(model_info)
        # Built after the model so it covers every global variable the model created.
        saver = tf.train.Saver(tf.global_variables())
        dest_node = return_ops(outputs)
        with tf.Session() as sess:
            saver.restore(sess, checkpoint_prefix)
            cur_graphdef = sess.graph.as_graph_def()
            output_graphdef = graph_util.convert_variables_to_constants(
                sess, cur_graphdef, [n.name for n in dest_node])

    # Write both files through tf.gfile so any destination GFile supports
    # works for the .info file as well as the .pb (builtin open() does not).
    with tf.gfile.GFile(output_path, 'wb') as gf:
        gf.write(output_graphdef.SerializeToString())
    with tf.gfile.GFile(output_path + '.info', 'w') as info_f:
        info_f.write(model_info)
def setup_input(dtype, shape, name=None):
    """Create the graph's input placeholder.

    Args:
        dtype: Element type of the placeholder (e.g. ``tf.float32``).
        shape: Static shape; ``None`` entries are left unconstrained.
        name: Optional op name; this is the input node name in the frozen graph.

    Returns:
        The placeholder tensor.
    """
    # The TF1 API is tf.placeholder (lowercase); tf.Placeholder does not exist
    # and raised AttributeError.
    return tf.placeholder(dtype=dtype, shape=shape, name=name)
def gen_info(inp, o):
    """Return a text summary of the graph's input and output tensors.

    Each tensor contributes a ``[<kind> tensor]: <name>`` line and a
    ``[<kind> shape]: <shape>`` line. *inp* and *o* may each be a single tensor
    or a (possibly nested) list/tuple of tensors, matching what ``return_ops``
    accepts. Multi-output models are therefore described instead of failing on
    ``list.name``.

    Args:
        inp: Input tensor(s); anything exposing ``.name`` and ``.get_shape()``.
        o: Output tensor(s); same requirements.

    Returns:
        The summary text, with one name line and one shape line per tensor.
    """
    def _flatten(x):
        # Recurse into lists/tuples; treat anything else as a single tensor.
        if isinstance(x, (list, tuple)):
            return [t for item in x for t in _flatten(item)]
        return [x]

    def _describe(kind, tensors):
        return ''.join(
            '[{0} tensor]: {1}\n[{0} shape]: {2}\n'.format(kind, t.name, t.get_shape())
            for t in _flatten(tensors))

    info_text = _describe('input', inp)
    print("outp", o)
    info_text += _describe('output', o)
    return info_text
def model_inference(images, num_classes):
    """Build the classification network and return its logits tensor.

    Template: replace the body with the real model definition. The scope name
    presumably has to match the one used at training time, so that the
    variable names line up with the checkpoint restored in ``dump_graph``.
    TODO: confirm.

    Args:
        images: Input tensor created by ``setup_input``.
        num_classes: Number of output classes.

    Returns:
        The logits tensor (once the template is filled in).

    Raises:
        NotImplementedError: Always, until the template body is replaced.
    """
    with tf.variable_scope('xxx'):
        # TODO: build the network here, e.g.
        #     logits = my_network(images, num_classes)
        #     return logits
        raise NotImplementedError(
            "model_inference() is a template: build the network under the "
            "training-time variable scope and return its logits tensor")
def _main():
    """Freeze the model to disk, then report completion."""
    dump_graph()
    print('dump finish!')


if __name__ == "__main__":
    _main()