如何查看ckpt的输入输出节点:
def getinout(input_checkpoint):
    """Dump every graph node name from a checkpoint's meta graph to ./nodes.txt.

    Inspect the resulting file by eye to locate the input/output node names
    (see the worked example below).

    Args:
        input_checkpoint: checkpoint path prefix; '<prefix>.meta' must exist.
    """
    # Importing the meta graph registers all nodes on the default graph;
    # the returned Saver is not needed just to list node names, and no
    # Session is required either (the original opened one and never used it).
    tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph_def = tf.get_default_graph().as_graph_def()
    # 'a+' appends across repeated runs; the context manager guarantees the
    # file is closed even if a write fails.
    with open('./nodes.txt', 'a+') as f:
        for node in graph_def.node:
            f.write(node.name + '\n')
n.name就是各个节点的名称,怎么找到呢?例子:
img_in---------------------》这个是输入节点
hh/conv1_1/conv2d/w/Initializer/truncated_normal/shape
hh/conv1_1/conv2d/w/Initializer/truncated_normal/mean
hh/res3_1/dw_bn/moving_mean/read
hh/fc4/w/Regularizer/l2_regularizer/L2Loss
hh/fc4/w/Regularizer/l2_regularizer
hh/fc4/fc4/Conv2D ----------》这个就是输出节点
Reshape/shape -----------》这个是神经网络输出后的操作步骤
Reshape
Reshape_1/shape
Reshape_1
Const
Mul
Mul_1
huber_loss/Sub
输入输出节点完全是跟你的网络定义顺序密切关联的,仔细核对就能找到。
将ckpt转换为pb:
def convt_ckpt_pb(checkpoint, metapath, output_graph):
    """Freeze a TensorFlow checkpoint into .pb / .pbtxt GraphDef files.

    Args:
        checkpoint: path prefix of the checkpoint holding the variable values.
        metapath: path prefix of the meta graph ('<metapath>.meta' must exist).
        output_graph: directory where frozen_model.pb / frozen_model.pbtxt
            are written.
    """
    # Output node names located via getinout(); comma-separated for the freezer.
    output_node_names = "one,two"
    # Create a session and make sure it is released even on failure.
    sess = tf.Session()
    try:
        # Import the best model's graph structure.
        saver = tf.train.import_meta_graph(metapath + '.meta', clear_devices=True)
        # BUG FIX: the original restored from the undefined name
        # `input_metapath`; the variable values live in `checkpoint`.
        saver.restore(sess, checkpoint)
        # Get the graph definition.
        gd = sess.graph.as_graph_def()
        # Bake variable values into constants so the graph is self-contained.
        converted_graph_def = graph_util.convert_variables_to_constants(
            sess, gd, output_node_names.split(","))
        # BUG FIX: the original wrote next to the undefined `input_checkpoint`;
        # honor the `output_graph` parameter (previously unused) instead.
        tf.train.write_graph(converted_graph_def, output_graph,
                             'frozen_model.pb', as_text=False)
        tf.train.write_graph(converted_graph_def, output_graph,
                             'frozen_model.pbtxt', as_text=True)
    finally:
        sess.close()
查看 输入输出节点:
cd c:/tensorflow131/source/tensorflow
bazel build tools/graph_transforms:summarize_graph
build完之后:
cd C:/_bazel_****/olqihtlt/execroot/org_tensorflow/bazel-out/x64_windows-opt/bin/tensorflow/tools/graph_transforms
summarize_graph --in_graph=D:/project/convert_model/frozen_model.pb
将pb模型转换为uff问题处理:
raise UffException("Transpose permutation has op " + str(tf_permutation_node.op) + ", expected Const. Only constant permuations are supported in UFF.")
uff.model.exceptions.UffException: Transpose permutation has op ConcatV2, expected Const. Only constant permuations are supported in UFF.怎么解决呢?
我们打印模型的输出节点.(在转换uff的时候打印)
import graphsurgeon as gs
import tensorflow as tf
import uff

if __name__ == "__main__":
    # USER DEFINED VALUES
    # BUG FIX: the original list literal opened with a full-width quote
    # (“one) — a SyntaxError in Python; replaced with a plain ASCII quote.
    output_nodes = ["one", "two"]
    input_node = ["in", "_in"]
    pb_file = "./frozen_model.pb"
    uff_file = "./frozen_model.uff"
    # END USER DEFINED VALUES

    # Read the TensorFlow graph.
    # NOTE: make sure the graph is frozen and optimized
    # (training nodes removed, etc.)
    dynamic_graph = gs.DynamicGraph(pb_file)
    nodes = [n.name for n in dynamic_graph.as_graph_def().node]
    print(nodes)  # print all node names here to locate inputs/outputs

    ns = {}
    for node in nodes:
        # Replace LeakyRelu with the default TRT plugin LReLU_TRT.
        if "LeakyRelu" in node:
            ns[node] = gs.create_plugin_node(name=node, op="LReLU_TRT",
                                             negSlope=0.1)
        # Replace Maximum with the L2Norm_Helper_TRT max operation
        # (CUDA's fmaxf).
        # if node == "orientation/l2_normalize/Maximum":
        if node == "embeddings/Maximum":
            ns[node] = gs.create_plugin_node(name=node, op="L2Norm_Helper_TRT",
                                             op_type=0, eps=1e-12)
        # Replace Rsqrt with the L2Norm_Helper_TRT rsqrt operation
        # (CUDA's rsqrtf).
        if node == "embeddings/Rsqrt":
            ns[node] = gs.create_plugin_node(name=node, op="L2Norm_Helper_TRT",
                                             op_type=1)
    dynamic_graph.collapse_namespaces(ns)

    # Write the UFF model to file.
    uff_model = uff.from_tensorflow(dynamic_graph.as_graph_def(),
                                    output_nodes=output_nodes,
                                    output_filename=uff_file, text=False)
排查后发现 tf.layers.dense 内部会用到 Transpose。为什么这里会触发它呢?因为 dense 的输入 shape 类似于 (?,1,1,128)——这才是问题的核心。
为了让 tf.layers.dense 不调用 Transpose,只需把输入 shape 变成 (?,128) 即可。因此在送入 tf.layers.dense 之前,先做一次 tf.reshape(input, (-1,128)) 处理。
重新训练得到模型,按照上面的方法,将ckpt转为pb,然后将pb转成uff就可以解决问题了。
关注他,获取更多干货