1. Implementing an on-device detection algorithm with TFLite and building the app

 

1 Environment setup

1.1 Install TensorFlow + CUDA + cuDNN + Anaconda

Set up the TensorFlow environment. TensorFlow 1.x is recommended; 2.x and above may run into training problems with this workflow.
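As a minimal sketch, the environment can be created with conda. The TensorFlow version below is an assumption; pick the CUDA/cuDNN versions that match your chosen release according to the official install tables.

conda create -n tf1 python=3.7
conda activate tf1
# tensorflow-gpu 1.14 expects CUDA 10.0 / cuDNN 7.x (assumption -- verify against the install tables)
pip install tensorflow-gpu==1.14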

1.2 Set up the TensorFlow Object Detection API

Model training uses the TensorFlow Object Detection API.

For the detailed setup procedure, see: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

git clone https://github.com/tensorflow/models

cd models/research
sudo apt-get install protobuf-compiler python-pil python-lxml
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

Run the following test; if it prints OK, the setup succeeded:

python object_detection/builders/model_builder_test.py

Ran 18 tests in 0.079s

OK

Note: export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim must be run again in every new shell before using the API.

Alternatively, edit ~/.bashrc (e.g. gedit ~/.bashrc) and add the following two lines, adjusting the paths to your own installation:

export PYTHONPATH="/home/andy/anaconda3/lib/python3.7/site-packages/tensorflow/models/research:$PYTHONPATH"
export PYTHONPATH="/home/andy/anaconda3/lib/python3.7/site-packages/tensorflow/models/research/slim:$PYTHONPATH"

# apply the change immediately

source ~/.bashrc

If you hit the following error, you are most likely running TensorFlow 2.0+; roll back to a 1.x version:

AttributeError: module 'tensorflow' has no attribute 'contrib'

2. Data annotation and preprocessing

2.1 Annotate images with labelImg

First install an annotation tool. I recently found a rather handy tool, 精灵标注助手 (Jingling Annotation Assistant), which can be used instead of labelImg. If you prefer labelImg, install it as follows:

Go to https://github.com/tzutalin/labelImg, download and extract the project, then run the following commands in the project directory:

sudo apt-get install pyqt5-dev-tools

pip install -r requirements/requirements-linux-python3.txt

make qt5py3

python labelImg.py

Once installed, annotate the data in Pascal VOC format.

2.2 Convert the data to TFRecord

Download the official code: https://github.com/tensorflow/models

1) Preprocessing

Batch-resize the images (together with their annotations):

Code: https://github.com/italojs/resize_dataset_pascalvoc
Enter the code directory:
cd /home/light/abaowork/tflite/data/data_pre/resize_dataset_pascalvoc-master/

python main.py -p /home/andy/data/all --output ./train --new_x 480 --new_y 640 --save_box_images 1

2) Convert the .xml files to CSV

1. Split the dataset into train and test sets.
2. Run xml_to_csv.py (implemented yourself; a sketch follows below). It produces two .csv files in the images folder: train_labels.csv and test_labels.csv.
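A minimal sketch of xml_to_csv.py, assuming the VOC .xml files for each split sit under images/train and images/test (these paths are assumptions; adjust them to your layout). The column order matches what generate_tfrecord.py reads in the next step (filename, width, height, class, xmin, ymin, xmax, ymax):

import glob
import os
import xml.etree.ElementTree as ET

import pandas as pd


def xml_to_csv(path):
    """Collect all VOC-style .xml annotations under `path` into one DataFrame."""
    rows = []
    for xml_file in glob.glob(os.path.join(path, '*.xml')):
        root = ET.parse(xml_file).getroot()
        filename = root.find('filename').text
        width = int(root.find('size/width').text)
        height = int(root.find('size/height').text)
        for obj in root.findall('object'):
            box = obj.find('bndbox')
            rows.append((filename, width, height,
                         obj.find('name').text,
                         int(box.find('xmin').text),
                         int(box.find('ymin').text),
                         int(box.find('xmax').text),
                         int(box.find('ymax').text)))
    columns = ['filename', 'width', 'height', 'class',
               'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(rows, columns=columns)


if __name__ == '__main__':
    # Assumed layout: images/train and images/test hold the annotated .xml files
    for split in ['train', 'test']:
        df = xml_to_csv(os.path.join('images', split))
        df.to_csv(os.path.join('images', split + '_labels.csv'), index=False)
        print('Wrote images/%s_labels.csv with %d boxes' % (split, len(df)))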

3) Convert to TFRecord

Edit the class names in class_text_to_int inside generate_tfrecord.py to match your own training set, adding or removing classes as needed:

def class_text_to_int(row_label):
    if row_label == 'nine':
        return 1
    elif row_label == 'ten':
        return 2
    else:
        return None

Run:

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python object_detection/generate_tfrecord.py --csv_input=object_detection/data_card/train_labels.csv --image_dir=object_detection/data_card/train --output_path=object_detection/data_card/train.record

python object_detection/generate_tfrecord.py --csv_input=object_detection/data_card/test_labels.csv --image_dir=object_detection/data_card/test --output_path=object_detection/data_card/test.record

3 Model training

3.1 Define the class information in pet_label_map.pbtxt
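The label map assigns each class name the same integer id used in class_text_to_int above (ids start at 1; 0 is reserved for the background). For the two example classes, pet_label_map.pbtxt would contain:

item {
  id: 1
  name: 'nine'
}
item {
  id: 2
  name: 'ten'
}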

3.2 Download a pretrained SSD model

Download from: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md. The model path is then set in the config file from section 3.3.

3.3 Configure the training parameters

ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config
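The fields that normally need editing in this config are sketched below; the values are placeholders based on the two-class example and the paths used earlier in this guide. Keep everything else from the downloaded file, including the graph_rewriter/quantization block, which the QUANTIZED_UINT8 conversion in section 4.2 relies on.

model {
  ssd {
    num_classes: 2  # number of classes in your label map
  }
}
train_config {
  fine_tune_checkpoint: "/path/to/pretrained/model.ckpt"  # from the model-zoo archive of 3.2
}
train_input_reader {
  label_map_path: "object_detection/data_card/pet_label_map.pbtxt"
  tf_record_input_reader {
    input_path: "object_detection/data_card/train.record"
  }
}
eval_input_reader {
  label_map_path: "object_detection/data_card/pet_label_map.pbtxt"
  tf_record_input_reader {
    input_path: "object_detection/data_card/test.record"
  }
}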

3.4 Train the model


python object_detection/model_main.py \
    --logtostderr \
    --model_dir=object_detection/training/ \
    --pipeline_config_path=/home/light/models/research/object_detection/training/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config

Monitor training progress:

tensorboard --logdir=./
watch -n 1 nvidia-smi

4. Model file conversion

4.1 Export to a frozen .pb graph

python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path=/home/andy/models/research/object_detection/training/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config \
    --trained_checkpoint_prefix=/home/andy/models/research/object_detection/training/model.ckpt-287 \
    --output_directory=object_detection/inference_graph_tflite \
    --add_postprocessing_op=true

4.2 Convert to TFLite format

1) Download the TensorFlow source code: https://github.com/tensorflow/tensorflow

2) Enter the source directory: /home/light/abaowork/tflite/models/research/tensorflow-master/tensorflow-master and install Bazel: https://docs.bazel.build/versions/master/install-ubuntu.html

Build toco:

bazel build tensorflow/lite/toco:toco

3) Convert to TFLite:


./bazel-bin/tensorflow/lite/toco/toco \
    --input_file=/home/andy/models/research/object_detection/inference_graph_tflite/tflite_graph.pb \
    --output_file=/home/andy/models/research/object_detection/inference_graph_tflite/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_values=128 \
    --change_concat_input_ranges=false \
    --allow_custom_ops

5 Android / iOS integration

Launch Android Studio: /data/下载文件/android-studio/bin/studio.sh
Download the example project: https://github.com/tensorflow/examples
Disable the model-download step of the build, and replace the model and label files referenced in DetectorActivity.java.
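As an assumed layout (directory and file names are taken from the TFLite object-detection example and may differ between demo versions), the converted model and a label file go into the app's assets folder, and the file-name constants in DetectorActivity.java should point at them:

lite/examples/object_detection/android/app/src/main/assets/
    detect.tflite    # from section 4.2
    labelmap.txt     # one class name per line, in the same order as the label-map ids
                     # (some demo versions expect a '???' background placeholder on the first line)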

Connect the phone to the computer, enable developer mode, and turn on USB debugging.

Run the project; the app is built and installed on the phone. With different datasets, the same pipeline yields different detection tasks.

References:

1. Introduction to TensorFlow Lite: https://tensorflow.google.cn/lite/

2. Android demo project: https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/android/README.md

3. Resizing images together with their annotations: https://github.com/italojs/resize_dataset_pascalvoc

4. CSDN introduction to the TensorFlow Object Detection API: https://blog.csdn.net/csdn_6105/article/details/82933628

5. Official tutorial [primary reference; may require access to sites blocked in mainland China]: https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193

6. Setting up the Object Detection API environment: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

Feel free to add me to discuss and learn together: WeChat (abaofight)
