TensorFlow TFLite 移动端(安卓)部署中 freeze 的细节:export_tflite_ssd_graph()

文章目录

    • 写在前面
    • tf 版本
    • main 代码 `./export_tflite_ssd_graph.py`
      • 先用自然语言描述下 main 代码
      • main 调用的 `pipeline_pb2.py`
      • 看下 main 代码
    • 看下 `export_tflite_ssd_graph_lib.py`
      • export_tflite_graph()

写在前面

本篇是 详解 TensorFlow TFLite 移动端(安卓)部署物体检测 demo(2)——量化模型 的辅助篇,重点讨论:

  • freeze 过程中发生了什么?其实就是把训练得到的参数/变量进行常量化/固化,这一点会贯穿本篇;
  • 解释为什么 ssd_mobilenet_v2 必须用 tf1 环境;
  • 解释 freeze 的下一步 convert 中, input_arrays 和 output_arrays 是怎么来的。

涉及到的主要代码和文件:
不特意提及的均在 /models/research/object_detection 下。

  • main 代码,./export_tflite_ssd_graph.py
  • main 直接调用的 lib,./export_tflite_ssd_graph_lib.py
  • main 直接调用的 pipeline,./protos/pipeline_pb2.py,由对应的 .proto 文件编译生成。
  • lib 调用的 exporter, ./exporter.py
  • lib 调用的 graph rewriter builder,./builders/graph_rewriter_builder.py
  • lib 调用的 model builder,./builders/model_builder.py
  • lib 调用的 post processing builder,./builders/post_processing_builder.py
  • lib 调用的 box list,./core/box_list.py
  • lib 调用的 tf version,./utils/tf_version.py

tf 版本

其中,我想打破顺序先说一下的,是最后一个 utils 代码 tf_version,用来判断你所使用的 tf 版本是 1 还是 2,其内容非常简单,如下:

"""Functions to check TensorFlow Version."""
from tensorflow.python import tf2  # pylint: disable=import-outside-toplevel

def is_tf1():
  """Whether current TensorFlow Version is 1.X."""
  return not tf2.enabled()
  
def is_tf2():
  """Whether current TensorFlow Version is 2.X."""
  return tf2.enabled()

因为 tf 1 和 2 之间有些不同之处,所以 models 项目有根据 tf 版本做适配。适配整体分两种情况:

  • 导入的 module 会有些许不同,如 lib 文件中有:
if tf_version.is_tf1():
  from tensorflow.tools.graph_transforms import TransformGraph  # pylint: disable=g-import-not-at-top
  • 处理流程会稍有不同,甚至根据版本来直接停止处理,报错提醒,如 model_buider.py 中的 _check_feature_extractor_exists 函数会根据 tf 版本的不同来判断当前 tf 版本是否支持 feature extractor

main 代码 ./export_tflite_ssd_graph.py

先用自然语言描述下 main 代码

  • 代码的输入:
    (1)人为指定 3 个:输出 graph 的存储路径、pipeline 路径、要被 frozen 的 ckpt 路径;
    (2)已默认的 6 个:检测输出框的最大数量(10)、单框检测类别数的最大数量(1)、每个类别参与 NMS 的总物体框最大数量(100)、增加后处理节点(True)、不使用常规 NMS 替代 fast NMS(False)、不修改 pipeline config(空)。
  • 代码的输出:适配 tflite 的 frozen graph。一共在指定输出路径下输出两个文件,分别是 tflite_graph.pbtxttflite_graph.pb

而 main 代码输出 graph 的输入输出直接决定了 freeze 之后 convert 到 tflite 时 节点名称,输入输出名称是本篇中相关代码中写好的,也就是说如果你改了相关代码中的节点名,在 convert 那一步就需要也跟着改就好了:

  • graph 的输入:'normalized_input_image_tensor': a float32 tensor of shape [1, height, width, 3] containing the normalized input image. 对于浮点型 model,范围为 [-1, 1);对于量化后 model,范围为 [0, 255]。
  • graph 的输出:
    (1)若有后处理,frozen graph 会增加一个 TFLite_Detection_PostProcess custom op 节点包含 4 项输出,分别是:
    detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box locations
    detection_classes: a float32 tensor of shape [1, num_boxes] with class indices
    detection_scores: a float32 tensor of shape [1, num_boxes] with class scores
    num_boxes: a float32 tensor of size 1 containing the number of detected boxes
    (2)若没有后处理,frozen graph 输出 2 项,分别是:
    'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4] containing the encoded box predictions.
    'raw_outputs/class_predictions': a float32 tensor of shape [1, num_anchors, num_classes] containing the class scores for each anchor after applying score conversion.

main 调用的 pipeline_pb2.py

仍然不按顺序先说下 pipeline_pb2.py,它由 pipeline.proto 生成,proto 是谷歌的亲儿子,用来处理序列化和反序列化。

可以看下 pipeline.proto 里面的内容,里面主要包含了配置训练和评估模型的 pipeline,依次有 DetectionModel、TrainConfig、InputReader、EvalConfig、InputReader、GraphRewriter,涵盖了训练和评估会涉及的过程,这样我们用一个文件就能配置整个 pipeline。
TensorFlow TFLite 移动端(安卓)部署中 freeze 的细节:export_tflite_ssd_graph()_第1张图片
而所使用的 config 文件中也会依次包含 model、train_config、train_input_reader、eval_config、eval_input_reader 等内容。如果你想在 inference 中使用和 train 和 val 不一样的 config pipeline,还可以修改其中的某些部分。

看下 main 代码

使用示例:

python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path path/to/ssd_mobilenet.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory

想修改 config 细节用于 inference 时:

#Example Usage (in which we change the NMS iou_threshold to be 0.5 and NMS score_threshold to be 0.0):
python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path path/to/ssd_mobilenet.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory
    --config_override " \
            model{ \
            ssd{ \
              post_processing { \
                batch_non_max_suppression { \
                        score_threshold: 0.0 \
                        iou_threshold: 0.5 \
                } \
             } \
          } \
       } \
       "

代码主体:
export_tflite_ssd_graph.py 主体如下,核心代码仅一句,调用 export_tflite_ssd_graph_lib.py ,lib 下的内容也都比较重要,所以下个标题细看一下。

# 针对 ssd mobilenet v2,整体还是基于 v1 的,用 v2 compat v1
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection import export_tflite_ssd_graph_lib
# 控制训练和验证的 pipeline
from object_detection.protos import pipeline_pb2

flags = tf.app.flags
# 前 3 个 必须要指定
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('pipeline_config_path', None, 'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')

# 后面这些可以取默认值
flags.DEFINE_integer('max_detections', 10, 'Maximum number of detections (boxes) to show.')
flags.DEFINE_integer('max_classes_per_detection', 1, 'Maximum number of classes to output per detection box.')
flags.DEFINE_integer('detections_per_class', 100, 'Number of anchors used per class in Regular Non-Max-Suppression.')

# 默认有后处理
flags.DEFINE_bool('add_postprocessing_op', True, 'Add TFLite custom op for postprocessing to the graph.')
# 默认使用 Fast NMS
flags.DEFINE_bool('use_regular_nms', False, 'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
# 默认不修改 config
flags.DEFINE_string('config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig text proto to override pipeline_config_path.')

FLAGS = flags.FLAGS


def main(argv):
    del argv  # Unused.
    flags.mark_flag_as_required('output_directory')
    flags.mark_flag_as_required('pipeline_config_path')
    flags.mark_flag_as_required('trained_checkpoint_prefix')
    
    # 所有 config 相关操作
    # 实例化
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()	
    # 读取指定的 config 文件
    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:	
        text_format.Merge(f.read(), pipeline_config)
    # 若指定 config override 会在这里合并
    text_format.Merge(FLAGS.config_override, pipeline_config)
    
    # 核心代码,就一行,调用了 export_tflite_ssd_graph_lib 的 export tflite graph
    export_tflite_ssd_graph_lib.export_tflite_graph(
        pipeline_config, 
        FLAGS.trained_checkpoint_prefix, 
        FLAGS.output_directory,
        FLAGS.add_postprocessing_op, 
        FLAGS.max_detections,
        FLAGS.max_classes_per_detection, 
        use_regular_nms=FLAGS.use_regular_nms)


if __name__ == '__main__':
    tf.app.run(main)

看下 export_tflite_ssd_graph_lib.py

lib 文件中一共 3 个函数,分别是:

  • 输出常量 center-size encoded anchors 的 get_const_center_size_encoded_anchors()
  • 增加 postprocessing op 的 append_postprocessing_op()
  • 被 main 直接调用的 export_tflite_graph()

export_tflite_graph()

整体代码过长,所以不会贴全流程,没必要,阅读感很差。

  • 首先,这个文件主要针对 ssd 模型,所以它首先会读取 pipeline config 检查是不是 ssd:
if pipeline_config.model.WhichOneof('model') != 'ssd':
	......
  • 之后,会读取 pipeline_config 文件中的其他信息,如 num_classes、nms_score_threshold、nms_iou_threshold、scale_values、image_resizer 等,本代码中只接受固定图片尺寸输入,如 300 × 300,这些参数在指定 add_postprocessing_op 时,会被添加到 model graph 中,后面会再涉及到:
num_classes = pipeline_config.model.ssd.num_classes
nms_score_threshold = {pipeline_config.model.ssd.post_processing.batch_non_max_suppression.score_threshold}
......
  • 指定输入节点名 input_arrays 为'normalized_input_image_tensor'
image = tf.placeholder(tf.float32, shape=shape, name='normalized_input_image_tensor')
  • 然后就会根据 config 中 model 的信息构建 build 模型,此函数支持 ssd、faster_rcnn、experimental_model、center_net 四种:
detection_model = model_builder.build(pipeline_config.model, is_training=False)
......

插播:为什么 ssd_mobilenet_v2 得用 tf1 环境?
上面 build 函数如下:

def build(model_config, is_training, add_summaries=True):
	...
	# ssd 情况就是这个文件中的 `_build_ssd_model()` 函数
    build_func = META_ARCH_BUILDER_MAP[meta_architecture]
    return build_func(getattr(model_config, meta_architecture), is_training,
                      add_summaries)

_build_ssd_model() 函数会检查 feature extractor 的 类型是不是在目前所使用的 tf 版本下对应的字典中:

_check_feature_extractor_exists(ssd_config.feature_extractor.type)
if tf_version.is_tf1():
  SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
      'ssd_inception_v2':
          SSDInceptionV2FeatureExtractor,
	  ...
      'ssd_mobilenet_v2':
          SSDMobileNetV2FeatureExtractor,
      ...
      }

ssd_mobilenet_v2 仅存在于 tf1 对应的 feature extractor 中,因此使用 ssd_mobilenet_v2 就必须使用 tf1 环境。


  • 然后得到模型的 “预测” 结果,继而通过 config 中的激活函数得到预测的类别,在这里是 sigmoid 函数:
predicted_tensors = detection_model.predict(image, true_image_shapes=None)
_, score_conversion_fn = post_processing_builder.build(pipeline_config.model.ssd.post_processing)
class_predictions = score_conversion_fn(predicted_tensors['class_predictions_with_background'])
  • 创建输出节点:
with tf.name_scope('raw_outputs'):
    tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
    tf.identity(class_predictions, name='class_predictions')
tf.identity(get_const_center_size_encoded_anchors(predicted_tensors['anchors']), name='anchors')
    
  • 接着就是对模型进行 frozen,使用 exporter.freeze_graph_with_def_protos():
frozen_graph_def = exporter.freeze_graph_with_def_protos(...)
  • 而这个函数继续跳转查看下去会走到 tensorflow.python.toolsfreeze_graph 中的 freeze_graph_with_def_protos,其中最核心的 freeze 代码为:
output_graph_def = graph_util.convert_variables_to_constants(...)
  • 如果需要 add_postprocessing_op,还会在 modes graph 中增加一个输出节点 'TFLite_Detection_PostProcess',它有 4 个输出,这也是 freeze 后面 convert 步骤中的 output_arrays
    if add_postprocessing_op:
        transformed_graph_def = append_postprocessing_op(...)
def append_postprocessing_op(...):
    new_output = frozen_graph_def.node.add()
    new_output.op = 'TFLite_Detection_PostProcess'
    new_output.name = 'TFLite_Detection_PostProcess'
    ...
  • 'TFLite_Detection_PostProcess' 的具体内容包括 max_detections、max_classes_per_detection、nms_score_threshold、nms_iou_threshold、num_classes、scale_values、detections_per_class 等 main 中指定的参数以及在 config 文件中读取的一些细节参数。

  • 至此,已经基本查看了整个代码中的核心流程,也就是完成了模型参数的常量化/固化/freeze,最后写入文件:

    binary_graph = os.path.join(output_dir, binary_graph_name)
    with tf.gfile.GFile(binary_graph, 'wb') as f:
        f.write(transformed_graph_def.SerializeToString())
        
    txt_graph = os.path.join(output_dir, txt_graph_name)
    with tf.gfile.GFile(txt_graph, 'w') as f:
        f.write(str(transformed_graph_def))

你可能感兴趣的:(#,移动端物体检测,tensorflow,边缘计算)