Porting ssd_mobilenet to mobile with TensorFlow Lite

Preparation

Download the TensorFlow source code

git clone https://github.com/tensorflow/tensorflow.git

Download the TensorFlow models repository

git clone https://github.com/tensorflow/models.git

Download and install the Bazel build tool
https://docs.bazel.build/versions/master/install.html
Download the TensorFlow demo apps

git clone https://github.com/tensorflow/examples.git

Using the TensorFlow Object Detection API

Install the dependencies and the protobuf compiler

pip install matplotlib pillow lxml Cython pycocotools
sudo apt-get install protobuf-compiler

Compile the protobuf definitions

cd models/research/
protoc object_detection/protos/*.proto --python_out=.

Add the environment variable to .bashrc

export PYTHONPATH=$PYTHONPATH:/home/wxy/tensorflow_model_new/models/research:/home/wxy/tensorflow_model_new/models/research/slim
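What this export accomplishes can be mimicked inside a Python process: entries in PYTHONPATH end up on sys.path at interpreter startup, which is what lets "import object_detection" and the slim modules resolve. A minimal sketch using the paths from the export above (adjust them to your own clone):

```python
import os
import sys

# The two directories from the PYTHONPATH export above; adjust to your clone.
research_dir = "/home/wxy/tensorflow_model_new/models/research"
slim_dir = os.path.join(research_dir, "slim")

# PYTHONPATH entries are placed on sys.path at interpreter startup;
# appending them at runtime has the same effect for the current process.
for d in (research_dir, slim_dir):
    if d not in sys.path:
        sys.path.append(d)

print(research_dir in sys.path)  # True
```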

Test that the TensorFlow models installation succeeded

python object_detection/builders/model_builder_test.py

Training your own dataset with the TensorFlow Object Detection API
Convert your data to the standard PASCAL VOC format and run the following from the object_detection directory. For mobile deployment it is best to use SSD MobileNet; the example below is from an earlier Faster R-CNN run of mine, but the procedure is the same.

python create_pascal_tf_record1.py --data_dir=/home/wxy/Faster_RCNN/dataset/VOCdevkit --year=VOC2012 --set=val --output_path=/home/wxy/Faster_RCNN/dataset/FTrecord/data/pascal_val.record --label_map_path=/home/wxy/Faster_RCNN/dataset/pascal_label_map.pbtxt

python create_pascal_tf_record1.py --data_dir=/home/wxy/Faster_RCNN/dataset/VOCdevkit --year=VOC2012 --set=train --output_path=/home/wxy/Faster_RCNN/dataset/FTrecord/data/pascal_train.record --label_map_path=/home/wxy/Faster_RCNN/dataset/pascal_label_map.pbtxt
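Under the hood, the record-creation script reads each PASCAL VOC XML annotation, normalizes the box coordinates by the image size, and writes the result into a TFRecord. A rough sketch of just the parsing step, with a hypothetical parse_voc_annotation helper (illustrative, not the actual script):

```python
import xml.etree.ElementTree as ET

# Hypothetical helper sketching the VOC annotation parsing the
# record-creation script performs before writing TFRecord examples.
def parse_voc_annotation(xml_string):
    root = ET.fromstring(xml_string)
    size = root.find("size")
    width = int(size.find("width").text)
    height = int(size.find("height").text)
    objects = []
    for obj in root.iter("object"):
        bbox = obj.find("bndbox")
        objects.append({
            "name": obj.find("name").text,
            # Coordinates are normalized to [0, 1] by the image size.
            "xmin": float(bbox.find("xmin").text) / width,
            "ymin": float(bbox.find("ymin").text) / height,
            "xmax": float(bbox.find("xmax").text) / width,
            "ymax": float(bbox.find("ymax").text) / height,
        })
    return {"width": width, "height": height, "objects": objects}

example_xml = """
<annotation>
  <size><width>500</width><height>375</height><depth>3</depth></size>
  <object>
    <name>dog</name>
    <bndbox><xmin>100</xmin><ymin>75</ymin><xmax>350</xmax><ymax>300</ymax></bndbox>
  </object>
</annotation>
"""

parsed = parse_voc_annotation(example_xml)
print(parsed["objects"][0]["name"])  # dog
```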

python train.py --train_dir /home/wxy/Faster_RCNN/train_dir_/ --pipeline_config_path /home/wxy/Faster_RCNN/dataset/faster_rcnn_inception_resnet_v2_atrous_voc.config

python export_inference_graph.py --input_type image_tensor --pipeline_config_path /home/wxy/Faster_RCNN/dataset/faster_rcnn_inception_resnet_v2_atrous_voc.config --trained_checkpoint_prefix /home/wxy/Faster_RCNN/train_dir_/model.ckpt-100000 --output_directory /home/wxy/Faster_RCNN/output

Building the TensorFlow source with Bazel

The third_party downloads can be slow and may break off midway, probably a mirror issue; just rerun the build a few times until it succeeds.

bazel build tensorflow/python/tools:freeze_graph
 
bazel build tensorflow/lite/toco:toco

Exporting the model

The official guide:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md

Many pretrained models can be found in the detection model zoo:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
I used ssd_mobilenet_v2_coco_2018_03_29. One line in its pipeline.config needs to be deleted, a batch-norm-related field (likely the line "batch_norm_trainable: true") that newer versions of the models repository no longer accept, since the repository is updated frequently.
In the object_detection directory:

python export_tflite_ssd_graph.py \
--pipeline_config_path=/home/wxy/tensorflow_model/models/research/object_detection/ssd_mobilenet_v2_coco_2018_03_29/pipeline.config \
--trained_checkpoint_prefix=/home/wxy/tensorflow_model/models/research/object_detection/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt \
--output_directory=/home/wxy/模型压缩/ssd_mobilenet/mobilenet_ssd_lite \
--add_postprocessing_op=true

In the TensorFlow source directory:
float32

bazel run tensorflow/lite/toco:toco -- \
--input_file=/home/wxy/模型压缩/ssd_mobilenet/mobilenet_ssd_lite/tflite_graph.pb \
--output_file=/home/wxy/模型压缩/ssd_mobilenet/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
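The four output arrays named in the command correspond to boxes, class indices, scores, and the number of valid detections. A sketch of how an app typically consumes them, using dummy values in place of real interpreter outputs (filter_detections is a hypothetical helper):

```python
# The TFLite_Detection_PostProcess op produces four output tensors:
#   output 0: boxes   [1, N, 4] as [ymin, xmin, ymax, xmax], normalized
#   output 1: classes [1, N]    (float indices into the label map)
#   output 2: scores  [1, N]
#   output 3: number of valid detections
def filter_detections(boxes, classes, scores, num, threshold=0.5):
    # Keep only detections whose confidence meets the threshold.
    results = []
    for i in range(int(num)):
        if scores[i] >= threshold:
            results.append({"box": boxes[i],
                            "class": int(classes[i]),
                            "score": scores[i]})
    return results

# Dummy tensors standing in for the interpreter's output tensors.
boxes = [[0.1, 0.1, 0.6, 0.5], [0.2, 0.3, 0.9, 0.8]]
classes = [17.0, 3.0]
scores = [0.92, 0.31]

detections = filter_detections(boxes, classes, scores, num=2)
print(len(detections))  # 1
```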

uint8 quantization

bazel-bin/tensorflow/lite/toco/toco \
--graph_def_file=/home/wxy/模型压缩/ssd_mobilenet/mobilenet_ssd_lite/tflite_graph.pb \
--output_file=/home/wxy/模型压缩/ssd_mobilenet/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_dev_values=128 \
--default_ranges_min=0 \
--default_ranges_max=6 \
--change_concat_input_ranges=False
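A note on the quantization flags: mean_values and std_dev_values describe how uint8 input pixels map back to the real values the model was trained on, real = (quantized - mean) / std, so mean=128 and std=128 map [0, 255] to roughly [-1, 1). The default_ranges_min/max pair supplies a fallback activation range (0 to 6 matches ReLU6) for ops without recorded min/max. A quick sanity check of that mapping:

```python
# real = (quantized - mean) / std, per the toco flags above.
def dequantize_input(pixel, mean=128.0, std=128.0):
    return (pixel - mean) / std

print(dequantize_input(0))    # -1.0
print(dequantize_input(128))  # 0.0
print(dequantize_input(255))  # 0.9921875
```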

Porting to Android

Install Android Studio
It is bundled in the Ubuntu software center, so downloading and installing it is easy.
Install the Android SDK and NDK
Some SDK components are installed by default, but the NDK is separate: on the Android Studio welcome screen click Configure, open Android SDK, find the NDK under SDK Tools, then download and install it.
Build the .apk
Open the following path in Android Studio and build the .apk:

/home/wxy/tensorflow_lite/examples/lite/examples/object_detection/android

Then look in the following path for app-debug.apk and install it on an Android device to test:

/home/wxy/tensorflow_lite/examples/lite/examples/object_detection/android/app/build/outputs/apk/debug

Using your own model
Place your own detect.tflite and labelmap.txt in the following path. The demo does not ship with a detect.tflite, and if it cannot find one it downloads the official model by default; the official model is impressively small, only 4.2 MB.

/home/wxy/tensorflow_lite/examples/lite/examples/object_detection/android/app/src/main/assets
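For reference, labelmap.txt is a plain text file with one class name per line, in class-index order; in the official demo's COCO labelmap the first line is a "???" placeholder for the background class. A shortened sketch (your file should list your own classes):

```text
???
person
bicycle
car
motorcycle
...
```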

Testing
On a Xiaomi Note, a phone from 2015, detection takes around 200 ms and the results look good.
