Tensorflow Android端开发之——模型节点信息查看

查看tensorflow 冻结的网络模型(pb格式的文件)节点时可用以下的代码实现

拿ssd-mobilenet v1模型进行试验;

代码部分:

import tensorflow as tf
with tf.Session() as sess:
    with open('./ssd_mobilenet_v1_android_export.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read()) 
        print (graph_def)

前提是你切换到模型所在目录,否则请将上面模型所在位置写成绝对路径。

打印图的部分结果如下:

node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/strided_slice/stack"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/strided_slice/stack_1"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/strided_slice/stack_2"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/sub/x"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 20
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/ExpandDims/dim"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/zeros/Const"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/concat/axis"
  op: "Const"
  input: "^Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/switch_f"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/pred_id"
  op: "Identity"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/Greater_1"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather/Switch"
  op: "Switch"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/MultiClassNonMaxSuppression/Gather_90/Gather_1"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/pred_id"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather"
  op: "Gather"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather/Switch:1"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/range/_113__cf__113"
  attr {
    key: "Tindices"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tparams"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "validate_indices"
    value {
      b: true
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Merge"
  op: "Merge"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/concat"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Gather"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/TensorArrayWrite_2/TensorArrayWriteV3"
  op: "TensorArrayWriteV3"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/TensorArrayWrite_2/TensorArrayWriteV3/Enter"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/Identity"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Merge"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/Identity_3"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/NextIteration_3"
  op: "NextIteration"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/TensorArrayWrite_2/TensorArrayWriteV3"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Shape/Switch"
  op: "Switch"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/MultiClassNonMaxSuppression/Gather_90/Gather_1"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/pred_id"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Shape"
  op: "Shape"
  input: "Postprocessor/BatchMultiClassNonMaxSuppression/map/while/PadOrClipBoxList/cond_1/Shape/Switch"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "out_type"
    value {
      type: DT_INT32
    }
  }
}

打印的图中显示了node的名字,输入,op信息等。

其输出节点名为:

Tensorflow Android端开发之——模型节点信息查看_第1张图片

Tensorflow Android端开发之——模型节点信息查看_第2张图片

Tensorflow Android端开发之——模型节点信息查看_第3张图片

采用上述的方式可以在新的会话中重新加载本地的模型文件(pb),然后二进制解析后,输出可以看到结果。但是如果网络层结构十分复杂,那么这种显示方式就会比较难以阅读,不利于查找输出节点信息。

我们可重新加载模型文件,用tensorboard进行可视化处理,代码如下:

>>> from tensorflow.python.platform import gfile
>>> model='./ssd_mobilenet_v1_android_export.pb'
>>> graph = tf.get_default_graph()
>>> graph_def = graph.as_graph_def()
>>> graph_def.ParseFromString(gfile.FastGFile(model, 'rb').read())
29083865
>>> tf.import_graph_def(graph_def, name='graph')
>>> summaryWriter = tf.summary.FileWriter('./log/', graph)

会在当前目录下生成一个log文件夹,在log文件夹下生成一个事件信息:events.out.tfevents.1543161022.DESKTOP-K1M2IEH

接下来用tensorboard打开该事件文件,指令如下:(我事先将路径切到了log目录下了)

path\assets\log> tensorboard --logdir=./

在浏览器端查看ssd-mobilenet v1的图结构如下:

Tensorflow Android端开发之——模型节点信息查看_第4张图片

从图中可看到输出节点名为:num_detections; detection_scores;detection_boxes; thesor的shape分别为?;(?,20);(?,20,4)

从上图可以方便的看出输出节点名是什么。


后来我在ssd mobilenet v2训练得到的pb模型上也做了测试,查看图可得到输出节点名与上面的一样。num_detections; detection_scores;detection_boxes;

输入的节点名为:image_tensor;

Tensorflow Android端开发之——模型节点信息查看_第5张图片


我在tiny-yolov2.pb上查看了下节点,其输入节点名为:input;shape为(?x416x416x3)输出的节点名为output;shape为(?x13x13x425);其中425=5x(5+80)表示在coco上训练的80个类别;


只有知道网络的输出节点,才可以更好理解网络主结构,便于移动端的inference的工作。因为很多针对冻结图网络做的inference任务需要指定输出节点,然后对输出节点tensor进行解析。

你可能感兴趣的:(TensorFlow,深度学习/机器学习,Android,移动端(边缘设备)深度学习)