本篇是 详解 TensorFlow TFLite 移动端(安卓)部署物体检测 demo(2)——量化模型 的辅助篇,重点讨论:
涉及到的主要代码和文件:
不特意提及的均在 /models/research/object_detection
下。
./export_tflite_ssd_graph.py
./export_tflite_ssd_graph_lib.py
./protos/pipeline_pb2.py
,由对应的 .proto 文件编译生成。./exporter.py
./builders/graph_rewriter_builder.py
./builders/model_builder.py
./builders/post_processing_builder.py
./core/box_list.py
./utils/tf_version.py
其中,我想打破顺序先说一下的,是最后一个 utils 代码 tf_version
,用来判断你所使用的 tf 版本是 1 还是 2,其内容非常简单,如下:
"""Functions to check TensorFlow Version."""
from tensorflow.python import tf2 # pylint: disable=import-outside-toplevel
def is_tf1():
"""Whether current TensorFlow Version is 1.X."""
return not tf2.enabled()
def is_tf2():
"""Whether current TensorFlow Version is 2.X."""
return tf2.enabled()
因为 tf 1 和 2 之间有些不同之处,所以 models
项目有根据 tf 版本做适配。适配整体分两种情况:
if tf_version.is_tf1():
from tensorflow.tools.graph_transforms import TransformGraph # pylint: disable=g-import-not-at-top
model_buider.py
中的 _check_feature_extractor_exists
函数会根据 tf 版本的不同来判断当前 tf 版本是否支持 feature extractor
。./export_tflite_ssd_graph.py
tflite_graph.pbtxt
和 tflite_graph.pb
。而 main 代码输出 graph 的输入输出直接决定了 freeze 之后 convert 到 tflite 时 节点名称,输入输出名称是本篇中相关代码中写好的,也就是说如果你改了相关代码中的节点名,在 convert 那一步就需要也跟着改就好了:
'normalized_input_image_tensor'
: a float32 tensor of shape [1, height, width, 3] containing the normalized input image. 对于浮点型 model,范围为 [-1, 1);对于量化后 model,范围为 [0, 255]。TFLite_Detection_PostProcess custom op
节点包含 4 项输出,分别是:detection_boxes
: a float32 tensor of shape [1, num_boxes, 4] with box locationsdetection_classes
: a float32 tensor of shape [1, num_boxes] with class indicesdetection_scores
: a float32 tensor of shape [1, num_boxes] with class scoresnum_boxes
: a float32 tensor of size 1 containing the number of detected boxes'raw_outputs/box_encodings'
: a float32 tensor of shape [1, num_anchors, 4] containing the encoded box predictions.'raw_outputs/class_predictions'
: a float32 tensor of shape [1, num_anchors, num_classes] containing the class scores for each anchor after applying score conversion.pipeline_pb2.py
仍然不按顺序先说下 pipeline_pb2.py
,它由 pipeline.proto
生成,proto 是谷歌的亲儿子,用来处理序列化和反序列化。
可以看下 pipeline.proto
里面的内容,里面主要包含了配置训练和评估模型的 pipeline,依次有 DetectionModel、TrainConfig、InputReader、EvalConfig、InputReader、GraphRewriter,涵盖了训练和评估会涉及的过程,这样我们用一个文件就能配置整个 pipeline。
而所使用的 config 文件中也会依次包含 model、train_config、train_input_reader、eval_config、eval_input_reader 等内容。如果你想在 inference 中使用和 train 和 val 不一样的 config pipeline,还可以修改其中的某些部分。
使用示例:
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path path/to/ssd_mobilenet.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
想修改 config 细节用于 inference 时:
#Example Usage (in which we change the NMS iou_threshold to be 0.5 and NMS score_threshold to be 0.0):
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path path/to/ssd_mobilenet.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
--config_override " \
model{ \
ssd{ \
post_processing { \
batch_non_max_suppression { \
score_threshold: 0.0 \
iou_threshold: 0.5 \
} \
} \
} \
} \
"
代码主体:
export_tflite_ssd_graph.py
主体如下,核心代码仅一句,调用 export_tflite_ssd_graph_lib.py
,lib 下的内容也都比较重要,所以下个标题细看一下。
# 针对 ssd mobilenet v2,整体还是基于 v1 的,用 v2 compat v1
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection import export_tflite_ssd_graph_lib
# 控制训练和验证的 pipeline
from object_detection.protos import pipeline_pb2
flags = tf.app.flags
# 前 3 个 必须要指定
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('pipeline_config_path', None, 'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')
# 后面这些可以取默认值
flags.DEFINE_integer('max_detections', 10, 'Maximum number of detections (boxes) to show.')
flags.DEFINE_integer('max_classes_per_detection', 1, 'Maximum number of classes to output per detection box.')
flags.DEFINE_integer('detections_per_class', 100, 'Number of anchors used per class in Regular Non-Max-Suppression.')
# 默认有后处理
flags.DEFINE_bool('add_postprocessing_op', True, 'Add TFLite custom op for postprocessing to the graph.')
# 默认使用 Fast NMS
flags.DEFINE_bool('use_regular_nms', False, 'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
# 默认不修改 config
flags.DEFINE_string('config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig text proto to override pipeline_config_path.')
FLAGS = flags.FLAGS
def main(argv):
del argv # Unused.
flags.mark_flag_as_required('output_directory')
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_prefix')
# 所有 config 相关操作
# 实例化
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
# 读取指定的 config 文件
with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
text_format.Merge(f.read(), pipeline_config)
# 若指定 config override 会在这里合并
text_format.Merge(FLAGS.config_override, pipeline_config)
# 核心代码,就一行,调用了 export_tflite_ssd_graph_lib 的 export tflite graph
export_tflite_ssd_graph_lib.export_tflite_graph(
pipeline_config,
FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory,
FLAGS.add_postprocessing_op,
FLAGS.max_detections,
FLAGS.max_classes_per_detection,
use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__':
tf.app.run(main)
export_tflite_ssd_graph_lib.py
lib 文件中一共 3 个函数,分别是:
get_const_center_size_encoded_anchors()
;append_postprocessing_op()
;export_tflite_graph()
;整体代码过长,所以不会贴全流程,没必要,阅读感很差。
if pipeline_config.model.WhichOneof('model') != 'ssd':
......
num_classes = pipeline_config.model.ssd.num_classes
nms_score_threshold = {pipeline_config.model.ssd.post_processing.batch_non_max_suppression.score_threshold}
......
'normalized_input_image_tensor'
:image = tf.placeholder(tf.float32, shape=shape, name='normalized_input_image_tensor')
detection_model = model_builder.build(pipeline_config.model, is_training=False)
......
插播:为什么 ssd_mobilenet_v2 得用 tf1 环境?
上面 build 函数如下:
def build(model_config, is_training, add_summaries=True):
...
# ssd 情况就是这个文件中的 `_build_ssd_model()` 函数
build_func = META_ARCH_BUILDER_MAP[meta_architecture]
return build_func(getattr(model_config, meta_architecture), is_training,
add_summaries)
而 _build_ssd_model()
函数会检查 feature extractor 的 类型是不是在目前所使用的 tf 版本下对应的字典中:
_check_feature_extractor_exists(ssd_config.feature_extractor.type)
if tf_version.is_tf1():
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
'ssd_inception_v2':
SSDInceptionV2FeatureExtractor,
...
'ssd_mobilenet_v2':
SSDMobileNetV2FeatureExtractor,
...
}
而 ssd_mobilenet_v2
仅存在于 tf1 对应的 feature extractor 中,因此使用 ssd_mobilenet_v2
就必须使用 tf1 环境。
predicted_tensors = detection_model.predict(image, true_image_shapes=None)
_, score_conversion_fn = post_processing_builder.build(pipeline_config.model.ssd.post_processing)
class_predictions = score_conversion_fn(predicted_tensors['class_predictions_with_background'])
with tf.name_scope('raw_outputs'):
tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
tf.identity(class_predictions, name='class_predictions')
tf.identity(get_const_center_size_encoded_anchors(predicted_tensors['anchors']), name='anchors')
frozen_graph_def = exporter.freeze_graph_with_def_protos(...)
tensorflow.python.tools
下 freeze_graph
中的 freeze_graph_with_def_protos
,其中最核心的 freeze 代码为:output_graph_def = graph_util.convert_variables_to_constants(...)
'TFLite_Detection_PostProcess'
,它有 4 个输出,这也是 freeze 后面 convert 步骤中的 output_arrays: if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op(...)
def append_postprocessing_op(...):
new_output = frozen_graph_def.node.add()
new_output.op = 'TFLite_Detection_PostProcess'
new_output.name = 'TFLite_Detection_PostProcess'
...
而 'TFLite_Detection_PostProcess'
的具体内容包括 max_detections、max_classes_per_detection、nms_score_threshold、nms_iou_threshold、num_classes、scale_values、detections_per_class 等 main 中指定的参数以及在 config 文件中读取的一些细节参数。
至此,已经基本查看了整个代码中的核心流程,也就是完成了模型参数的常量化/固化/freeze,最后写入文件:
binary_graph = os.path.join(output_dir, binary_graph_name)
with tf.gfile.GFile(binary_graph, 'wb') as f:
f.write(transformed_graph_def.SerializeToString())
txt_graph = os.path.join(output_dir, txt_graph_name)
with tf.gfile.GFile(txt_graph, 'w') as f:
f.write(str(transformed_graph_def))