We use the ssd-mobilenet v1 model for this experiment.
Code:
import tensorflow as tf

with tf.Session() as sess:
    # Read the frozen graph (pb) and parse it into a GraphDef proto.
    with open('./ssd_mobilenet_v1_android_export.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        print(graph_def)
This assumes you have changed into the directory where the model lives; otherwise, replace the path above with the model's absolute path.
Part of the printed graph looks like this:
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/strided_slice/stack"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/strided_slice/stack_1"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/strided_slice/stack_2"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/sub/x"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 20
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/ExpandDims/dim"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/zeros/Const"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/concat/axis"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/pred_id"
  op: "Identity"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/Greater_1"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather/Switch"
  op: "Switch"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/MultiClassNonMaxSuppression/Gather_90/Gather_1"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/pred_id"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather"
  op: "Gather"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather/Switch:1"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/range/_113__cf__113"
  attr {
    key: "Tindices"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tparams"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "validate_indices"
    value {
      b: true
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Merge"
  op: "Merge"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/concat"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/TensorArrayWrite_2/TensorArrayWriteV3"
  op: "TensorArrayWriteV3"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/TensorArrayWrite_2/TensorArrayWriteV3/Enter"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/Identity"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Merge"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/Identity_3"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/NextIteration_3"
  op: "NextIteration"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/TensorArrayWrite_2/TensorArrayWriteV3"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Shape/Switch"
  op: "Switch"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/MultiClassNonMaxSuppression/Gather_90/Gather_1"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/pred_id"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Shape"
  op: "Shape"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Shape/Switch"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "out_type"
    value {
      type: DT_INT32
    }
  }
}
The printed graph shows each node's name, inputs, op type, and so on; what we are really after are the output node names.
Reloading the local model file (pb) into a new session this way and parsing its binary contents lets you inspect the result, but if the network structure is complex, this kind of dump is hard to read and makes it difficult to locate the output node information.
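As a lighter-weight alternative to dumping the entire GraphDef, a minimal sketch like the following (assuming graph_def has already been parsed from the pb file as above; the "unconsumed nodes" heuristic is my own addition, not part of the original workflow) prints only each node's name and op, and then lists nodes that are never used as an input to any other node, which are usually the graph outputs:

# Assumes graph_def was parsed from the pb file as in the snippet above.
for node in graph_def.node:
    print(node.name, node.op)

# Heuristic: nodes that never appear as an input of another node
# are candidates for the graph's output nodes.
consumed = set()
for node in graph_def.node:
    for inp in node.input:
        consumed.add(inp.split(':')[0].lstrip('^'))
print([node.name for node in graph_def.node if node.name not in consumed])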
>>> import tensorflow as tf
>>> from tensorflow.python.platform import gfile
>>> model = './ssd_mobilenet_v1_android_export.pb'
>>> graph = tf.get_default_graph()
>>> graph_def = graph.as_graph_def()
>>> graph_def.ParseFromString(gfile.FastGFile(model, 'rb').read())
29083865
>>> tf.import_graph_def(graph_def, name='graph')
>>> summaryWriter = tf.summary.FileWriter('./log/', graph)
This creates a log folder under the current directory containing an event file such as events.out.tfevents.1543161022.DESKTOP-K1M2IEH.
Next, open that event file with TensorBoard using the following command (I had already changed into the log directory):
path\assets\log> tensorboard --logdir=./
You can then view the graph structure of ssd-mobilenet v1 in the browser (TensorBoard serves it at http://localhost:6006 by default).
From the graph, the output node names can be read off as num_detections, detection_scores, and detection_boxes; their tensor shapes are (?), (?, 20), and (?, 20, 4) respectively.
TensorBoard makes it easy to see at a glance what the output node names are.
I later ran the same check on a pb model trained with ssd mobilenet v2; inspecting its graph shows the same output node names as above: num_detections, detection_scores, detection_boxes.
The input node name is image_tensor.
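As a cross-check that does not require TensorBoard, a small sketch along these lines (my own addition; it assumes the graph is imported with name='graph' as in the interactive session above, so the imported tensor names carry a graph/ prefix) looks the tensors up by name and prints their static shapes:

import tensorflow as tf
from tensorflow.python.platform import gfile

# Parse the frozen graph and import it under the 'graph' name scope.
graph_def = tf.GraphDef()
with gfile.FastGFile('./ssd_mobilenet_v1_android_export.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='graph')

graph = tf.get_default_graph()
for name in ['graph/image_tensor:0',
             'graph/num_detections:0',
             'graph/detection_scores:0',
             'graph/detection_boxes:0']:
    tensor = graph.get_tensor_by_name(name)
    print(name, tensor.shape)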
I also inspected the nodes in tiny-yolov2.pb: its input node name is input with shape (?, 416, 416, 3), and its output node name is output with shape (?, 13, 13, 425), where 425 = 5 × (5 + 80): 5 anchor boxes, each predicting 5 box parameters plus scores for the 80 COCO classes.
Only once you know the network's output nodes can you properly understand the network's main structure and carry out inference work on the mobile side, because many inference tasks on a frozen graph require you to specify the output nodes and then parse the tensors those nodes produce.
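For illustration, here is a minimal inference sketch under stated assumptions (TF1-style API, the node names found above, and a dummy all-zero uint8 image in place of a real input); it is not the exact mobile-side code, just the pattern of feeding the input node and fetching the output nodes:

import numpy as np
import tensorflow as tf

# Load the frozen graph into a fresh Graph (no name prefix this time).
graph_def = tf.GraphDef()
with open('./ssd_mobilenet_v1_android_export.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # Dummy uint8 batch of shape (1, H, W, 3); replace with a real image.
    image = np.zeros((1, 300, 300, 3), dtype=np.uint8)
    boxes, scores, num = sess.run(
        ['detection_boxes:0', 'detection_scores:0', 'num_detections:0'],
        feed_dict={'image_tensor:0': image})
    print(boxes.shape, scores.shape, num)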