Tensorflow 模型转 tflite ,在安卓端使用

自己在将tensorflow模型移动端部署的时候(使用 tensorflow lite),踩了很多坑,查了很多资料,现在做个记录,所有参考资料在文章最后 参考 处列出。

tensorflow lite是TensorFlow Lite 是 Google I/O 2017 大会上的其中一个重要宣布,有了TensorFlow Lite,应用开发者可以在移动设备上部署人工智能。
tensorflow lite 【github】

Tensorflow 模型转 tflite ,在安卓端使用_第1张图片

基本思路:

  1. 在pc端进行 Tensorflow 模型训练,保存训练模型
  2. 使用 工具将该模型转换为 Tensorflow lite 模型
  3. 在Android上使用

tensorflow模型持久化

在tensorflow中进行模型训练,得到适合自己项目的模型。Tensorflow 模型训练好之后会生成三个文件:

  • model.ckpt.meta :保存Tensorflow计算图结构,可以理解为神经网络的网络结构
  • model.ckpt :保存Tensorflow程序中每一个变量的取值,变量是模型中可训练的部分
  • checkpoint :保存一个目录下所有模型文件列表
# 使用tf.train.write_graph导出GraphDef文件
tf.train.write_graph(sess.graph_def, "./", "mz_graph.pb", as_text=False)
# 使用tf.train.save导出checkpoint文件
saver.save(sess, model_path)

生成的模型文件如下图所示:
Tensorflow 模型转 tflite ,在安卓端使用_第2张图片

bazel编译需要的工具

Tensoflow使用的编译工具是 bazel,谷歌开源的自动化构建工具。【bazel传送门】
安装bazel,用来编译 tensorflow 转 tflite 时用到的几个工具,freeze、toco、summarize_graph(具体作用下面说),这些工具都在 tensorflow(从github上clone) 中,按下面命令进行编译(在 tensorflow目录下进行):

bazel build tensorflow/python/tools:freeze_graph

bazel build tensorflow/contrib/lite/toto:toto

Bazel build tensorflow/tools/graph_transforms:summarize_graph  (查看模型结构,找出输入输出)

模型转换

将训练好的tf模型,进行freeze、toco操作,freeze主要是将 tensorflow模型持久化 中生成的文件进行合并,得到一个变量值和运算图模型相结合的文件,是将变量值固定在图中的操作。如上图,这步生成 mz_freezegraph.pb .

summarize_graph

该命令查看整个Tensorflow模型概况,使用命令如下,运行之后,得到自己整个网络结构,从中可以找到自己模型的输入输出,如下图(模型比较乱。。。)

# --in_graph=” 后面是模型存储的位置
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=../mz_graph.pb

Tensorflow 模型转 tflite ,在安卓端使用_第3张图片

freeze_graph

该命令是 Tensorflow模型固化,将Tensorflow模型和计算图上变量的值合二为一,方便直接转换 Tensorflow lite 模型。

    bazel-bin/tensorflow/python/tools/freeze_graph\
        --input_graph=/tmp/mobilenet_v1_224.pb \
        --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
        --input_binary=true \
        --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
        --output_node_names=MobileNet/Predictions/Reshape_1
  • input_graph :Tensorflow 模型结构文件
  • input_checkpoint :Tensorflow 模型 ckpt 文件
  • output_graph :输出的freeze文件
  • output_node_names :模型输出节点名字,使用 summarize_graph 查看 ,可以在 Tensorflow 网络训练时进行命名

这里写图片描述

toco

固化模型到 tflite 模型转化

toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
      --input_format=TENSORFLOW_GRAPHDEF \
      --output_format=TFLITE \
      --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
      --inference_type=FLOAT \
      --input_type=FLOAT \
      --input_arrays=input \
      --output_arrays=MobilenetV1/Predictions/Reshape_1 \
      --input_shapes=1,224,224,3
  • input_file : freeze 之后的 Tensorflow 模型文件
  • output_file :转换好的 Tensorflow lite 模型,扩展名为 .tflite
  • output_arrays :仍然是Tensorflow 模型的输出
  • input_shapes :输入图片的维度

这里写图片描述

部署Android

1、安装 官方GitHub进行Android软件搭建 Tensorflow lite 【Github】
2、工程中有 FloatQuantized 两个模式可选,如下图,这里使用Float,Quantized需要先量化模型,在进行 tflite 模型转换。
3、将生成的 .tflite 文件和 对应的 labels.txt 文件放入Android工程的 assets 文件中。
4、运行即可。

Tensorflow 模型转 tflite ,在安卓端使用_第4张图片

参考

  1. TensorFlow Lite学习笔记2:生成TFLite模型文件
  2. TensorFlow固化模型
  3. TensorFlow Lite模型生成以及bazel的安装使用、出现的问题及解决方案整合
  4. Tensorflow Lite之编译生成tflite文件
  5. tensorflow Lite的使用
  6. tensorflow模型量化
  7. 用 TensorFlow 压缩神经网络
  8. 在Android上使用TensorFlow Lite

你可能感兴趣的:(深度学习)