Download the TensorFlow source code
git clone https://github.com/tensorflow/tensorflow.git
Download the TensorFlow models repository
git clone https://github.com/tensorflow/models.git
Download and install the Bazel build tool
https://docs.bazel.build/versions/master/install.html
Download the TensorFlow examples (demo apps)
git clone https://github.com/tensorflow/examples.git
Install the required Python packages and the protobuf compiler
pip install matplotlib pillow lxml Cython pycocotools
sudo apt-get install protobuf-compiler
Compile the protobuf files
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
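A quick way to confirm the protoc step worked is to check that every .proto file now has a generated *_pb2.py module next to it. This is a minimal stdlib-only sketch; the function name missing_pb2_modules is hypothetical, and it assumes you run it from models/research like the commands above.

```python
from pathlib import Path


def missing_pb2_modules(protos_dir):
    """Return the .proto files in protos_dir that have no matching *_pb2.py.

    protoc --python_out=. turns foo.proto into foo_pb2.py in the same
    package directory, so a .proto without a sibling *_pb2.py was not compiled.
    """
    protos = Path(protos_dir)
    missing = []
    for proto in sorted(protos.glob("*.proto")):
        if not (protos / f"{proto.stem}_pb2.py").exists():
            missing.append(proto.name)
    return missing
```

Usage: `print(missing_pb2_modules("object_detection/protos"))` should print an empty list once compilation has succeeded.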
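Add the environment variable to ~/.bashrc (adjust the path to wherever you cloned models), then run source ~/.bashrc or open a new terminal: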
export PYTHONPATH=$PYTHONPATH:/home/wxy/tensorflow_model_new/models/research:/home/wxy/tensorflow_model_new/models/research/slim
Test whether the TensorFlow models installation works (run from models/research):
python object_detection/builders/model_builder_test.py
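If the test fails with an import error, the usual cause is PYTHONPATH. The sketch below only checks whether Python can locate the two top-level packages the API needs: object_detection (from models/research) and nets (from models/research/slim). The helper name check_importable is hypothetical.

```python
import importlib.util


def check_importable(module_names):
    """Map each top-level module name to whether Python can locate it.

    find_spec returns None for a missing top-level module instead of raising.
    """
    return {name: importlib.util.find_spec(name) is not None for name in module_names}
```

Usage: `check_importable(["object_detection", "nets"])` should return True for both entries once the ~/.bashrc change is active in the current shell.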
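Training the TensorFlow Object Detection API on your own dataset
Convert the data to the standard Pascal VOC format, then run the following commands from the object_detection directory. For deployment on mobile, SSD MobileNet is the best choice; the example below comes from my earlier Faster R-CNN setup, but the procedure is the same.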
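In VOC format every image has an XML annotation with the image size and one bndbox per object. As far as I know, the official create_pascal_tf_record.py reads exactly these fields and divides the box coordinates by the image width and height before writing the TFRecord. The sketch below mirrors that parsing with the standard library only; parse_voc_annotation is a hypothetical helper, not part of the API.

```python
import xml.etree.ElementTree as ET


def parse_voc_annotation(xml_text):
    """Parse one Pascal VOC annotation; box coordinates are normalized to [0, 1]."""
    root = ET.fromstring(xml_text)
    size = root.find("size")
    width = float(size.find("width").text)
    height = float(size.find("height").text)
    objects = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        objects.append({
            "name": obj.find("name").text,  # must match a name in the label map
            "xmin": float(box.find("xmin").text) / width,
            "ymin": float(box.find("ymin").text) / height,
            "xmax": float(box.find("xmax").text) / width,
            "ymax": float(box.find("ymax").text) / height,
        })
    return {
        "filename": root.find("filename").text,
        "width": int(width),
        "height": int(height),
        "objects": objects,
    }
```

Running it over a few files from Annotations is a cheap way to catch empty or malformed XML before the conversion script does.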
python create_pascal_tf_record1.py --data_dir=/home/wxy/Faster_RCNN/dataset/VOCdevkit --year=VOC2012 --set=val --output_path=/home/wxy/Faster_RCNN/dataset/FTrecord/data/pascal_val.record --label_map_path=/home/wxy/Faster_RCNN/dataset/pascal_label_map.pbtxt
python create_pascal_tf_record1.py --data_dir=/home/wxy/Faster_RCNN/dataset/VOCdevkit --year=VOC2012 --set=train --output_path=/home/wxy/Faster_RCNN/dataset/FTrecord/data/pascal_train.record --label_map_path=/home/wxy/Faster_RCNN/dataset/pascal_label_map.pbtxt
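Both commands need the label map passed via --label_map_path. It is a text protobuf with one item per class; ids start at 1 because 0 is reserved for the background. A minimal sketch for generating it from a list of class names (make_label_map and the class names are hypothetical):

```python
def make_label_map(class_names):
    """Build label_map.pbtxt text; ids start at 1 because 0 means background."""
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append(f"item {{\n  id: {idx}\n  name: '{name}'\n}}\n")
    return "\n".join(items)
```

The names must match the `<name>` values in the VOC XML files, and the number of items should match num_classes in the pipeline config.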
python train.py --train_dir /home/wxy/Faster_RCNN/train_dir_/ --pipeline_config_path /home/wxy/Faster_RCNN/dataset/faster_rcnn_inception_resnet_v2_atrous_voc.config
python export_inference_graph.py --input_type image_tensor --pipeline_config_path /home/wxy/Faster_RCNN/dataset/faster_rcnn_inception_resnet_v2_atrous_voc.config --trained_checkpoint_prefix /home/wxy/Faster_RCNN/train_dir_/model.ckpt-100000 --output_directory /home/wxy/Faster_RCNN/output
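--trained_checkpoint_prefix takes the shared prefix (model.ckpt-100000), not one of the .index/.data/.meta files. TensorFlow itself provides tf.train.latest_checkpoint(train_dir), which reads the `checkpoint` file; the stdlib sketch below instead scans the .index files, which avoids importing TensorFlow. latest_checkpoint_prefix is a hypothetical helper.

```python
import re
from pathlib import Path


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-N prefix with the largest step N, or None.

    Steps are compared as integers, so model.ckpt-100000 beats model.ckpt-500
    even though "500" sorts after "100000" as a string.
    """
    best_step, best_prefix = -1, None
    for index_file in Path(train_dir).glob("model.ckpt-*.index"):
        match = re.fullmatch(r"model\.ckpt-(\d+)\.index", index_file.name)
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = str(index_file.with_suffix(""))  # drop ".index"
    return best_prefix
```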
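Build freeze_graph and toco from the TensorFlow source directory. Downloading the third_party dependencies can be slow (probably a network or mirror issue) and may be interrupted; just rerun the command a few times until it finishes: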
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco
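Since the fix for interrupted downloads is simply to rerun the build, the retrying can be scripted. A minimal sketch with a hypothetical helper run_with_retries; the attempt count and delay are arbitrary assumptions:

```python
import subprocess
import time


def run_with_retries(cmd, max_attempts=5, delay_seconds=10):
    """Re-run cmd until it exits with status 0 or max_attempts is used up.

    Returns the number of the attempt that succeeded; raises
    CalledProcessError with the last exit code if every attempt fails.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        result = subprocess.run(cmd)
        if result.returncode == 0:
            return attempt
        if attempt < max_attempts:
            time.sleep(delay_seconds)  # give the network a moment before retrying
    raise subprocess.CalledProcessError(result.returncode, cmd)
```

Usage from the TensorFlow source root: `run_with_retries(["bazel", "build", "tensorflow/lite/toco:toco"])`. Bazel keeps external repositories it has already fetched, so each retry should only need to download what failed before.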
The official guide:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md
Many pretrained models are available in the model zoo:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
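I used ssd_mobilenet_v2_coco_2018_03_29. Its pipeline.config needs one change: a single line, apparently related to the BN (batch normalization) layers, has to be deleted, because the TensorFlow models repository is updated frequently and the config shipped with this model no longer matches the current code.
From the object_detection directory: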
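The exact line is not named above. One likely candidate, which is an assumption here, is `batch_norm_trainable: true` inside the feature_extractor block: newer versions of the API dropped that field, and parsing an old config then fails with an error along the lines of "has no field named batch_norm_trainable". A minimal sketch that deletes every line setting a given field in the text-format config (drop_config_field is hypothetical):

```python
import re


def drop_config_field(config_text, field_name):
    """Remove every line of a text-format pipeline.config that sets field_name.

    Only lines of the form `<indent>field_name: value` are removed, so a
    longer field such as `field_name_extra:` is left alone.
    """
    pattern = re.compile(rf"^\s*{re.escape(field_name)}\s*:")
    kept = [line for line in config_text.splitlines(keepends=True) if not pattern.match(line)]
    return "".join(kept)
```

Usage (assuming the field above): read pipeline.config, pass its text through `drop_config_field(text, "batch_norm_trainable")`, and write the result back before running the export script.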
python export_tflite_ssd_graph.py \
--pipeline_config_path=/home/wxy/tensorflow_model/models/research/object_detection/ssd_mobilenet_v2_coco_2018_03_29/pipeline.config \
--trained_checkpoint_prefix=/home/wxy/tensorflow_model/models/research/object_detection/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt \
--output_directory=/home/wxy/模型压缩/ssd_mobilenet/mobilenet_ssd_lite \
--add_postprocessing_op=true
From the TensorFlow source directory, convert tflite_graph.pb with toco.
Float32:
bazel run tensorflow/lite/toco:toco -- \
--input_file=/home/wxy/模型压缩/ssd_mobilenet/mobilenet_ssd_lite/tflite_graph.pb \
--output_file=/home/wxy/模型压缩/ssd_mobilenet/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
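With --add_postprocessing_op=true the graph ends in the TFLite_Detection_PostProcess custom op, which is why --allow_custom_ops is required. Its four outputs (the four --output_arrays above) are, per the official guide, detection boxes, classes, scores, and the number of detections. The sketch below shows how they are typically turned into pixel boxes; it assumes batch size 1, boxes ordered [ymin, xmin, ymax, xmax] and normalized to [0, 1], and plain nested lists standing in for the interpreter's output arrays. decode_detections and its 0.5 threshold are hypothetical.

```python
def decode_detections(boxes, classes, scores, num_detections,
                      image_width, image_height, score_threshold=0.5):
    """Convert the four TFLite_Detection_PostProcess outputs into pixel boxes.

    boxes: [1][N][4] as [ymin, xmin, ymax, xmax] in [0, 1]
    classes, scores: [1][N]; num_detections: [1] (all returned as floats)
    Returns dicts with an int class index, the score, and (x0, y0, x1, y1).
    """
    results = []
    for i in range(int(num_detections[0])):
        score = scores[0][i]
        if score < score_threshold:
            continue  # drop low-confidence detections
        ymin, xmin, ymax, xmax = boxes[0][i]
        results.append({
            "class_index": int(classes[0][i]),
            "score": score,
            "box": (round(xmin * image_width), round(ymin * image_height),
                    round(xmax * image_width), round(ymax * image_height)),
        })
    return results
```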
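Uint8 quantization (this writes to the same detect.tflite as the float command, so run only one of the two or change --output_file). Because ssd_mobilenet_v2_coco_2018_03_29 was not trained with quantization-aware training, --default_ranges_min/--default_ranges_max supply dummy ranges for activations that have no recorded min/max, which may cost some accuracy: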
bazel-bin/tensorflow/lite/toco/toco \
--input_file=/home/wxy/模型压缩/ssd_mobilenet/mobilenet_ssd_lite/tflite_graph.pb \
--output_file=/home/wxy/模型压缩/ssd_mobilenet/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_values=128 \
--default_ranges_min=0 \
--default_ranges_max=6 \
--change_concat_input_ranges=false \
--allow_custom_ops
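The mean/std and default-range flags reduce to simple affine arithmetic. For a uint8 input, a mean of 128 and a std of 128 declare that byte q represents the real value (q - 128) / 128, roughly [-1, 1], which is the range SSD MobileNet's own preprocessing produces. Default ranges of 0 and 6 give every activation without a recorded range the interval [0, 6], presumably chosen to match MobileNet's ReLU6. The sketch below works through both; the function names are hypothetical and it ignores the zero-point nudging TFLite performs internally.

```python
def input_real_value(q, mean_value=128.0, std_value=128.0):
    """Real value of a uint8 input byte: (q - mean) / std."""
    return (q - mean_value) / std_value


def affine_params(range_min, range_max, levels=256):
    """Scale and zero point of uint8 affine quantization over [range_min, range_max]."""
    scale = (range_max - range_min) / (levels - 1)
    zero_point = round(-range_min / scale)
    return scale, zero_point


def quantize(real, scale, zero_point):
    """Map a real value to uint8, clamping to [0, 255]."""
    return max(0, min(255, round(real / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """Map a uint8 value back to the real value it represents."""
    return scale * (q - zero_point)
```

For example, input bytes 0, 128 and 255 map to -1.0, 0.0 and 0.9921875, and with the [0, 6] default range one uint8 step is 6/255, about 0.0235.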
Installing Android Studio
It is available in the Ubuntu Software center, so downloading and installing it is easy.
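Downloading and installing the Android SDK
Some components are installed by default, but the NDK has to be added: on the Android Studio welcome screen, open Configure, go to the Android SDK settings, find NDK under SDK Tools, then download and install it.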
Building the .apk
Open the following directory in Android Studio and run Build to generate the .apk:
/home/wxy/tensorflow_lite/examples/lite/examples/object_detection/android
The generated app-debug.apk is in the following directory; copy it to an Android phone and install it to test:
/home/wxy/tensorflow_lite/examples/lite/examples/object_detection/android/app/build/outputs/apk/debug
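Using your own model
Put your own detect.tflite and labelmap.txt in the following directory. The demo does not include a detect.tflite at first; if none is found, it downloads the official detect.tflite by default. The official model is only 4.2 MB, which is impressive work by Google.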
/home/wxy/tensorflow_lite/examples/lite/examples/object_detection/android/app/src/main/assets
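labelmap.txt is a plain list with one label per line. As far as I can tell from the labelmap.txt bundled with the demo, the first line is a "???" placeholder for the background, followed by the classes in id order. Assuming that layout, this sketch converts the label map used for training into labelmap.txt; pbtxt_to_labelmap_txt is hypothetical, and it assumes each item lists id before name and that ids run contiguously from 1.

```python
import re


def pbtxt_to_labelmap_txt(pbtxt_text, background="???"):
    """Convert label_map.pbtxt text into the one-label-per-line labelmap.txt.

    Line 0 is the background placeholder; line k holds the class with id k.
    """
    items = re.findall(r"id:\s*(\d+)\s*name:\s*['\"]([^'\"]+)['\"]", pbtxt_text)
    names = {int(idx): name for idx, name in items}
    if sorted(names) != list(range(1, len(names) + 1)):
        raise ValueError("label ids must be contiguous starting at 1")
    lines = [background] + [names[i] for i in range(1, len(names) + 1)]
    return "\n".join(lines) + "\n"
```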
Testing
On a Xiaomi Note, a phone from 2015, detection takes about 200 ms and the results are quite good.