配置tensorflow环境,最好安装tensorflow1.X,2以上可能会遇到训练问题。
模型训练使用tensorflow object detection API。
具体配置方法可以参考如下链接配置:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
git clone https://github.com/tensorflow/models
cd models/research sudo apt-get install protobuf-compiler python-pil python-lxml protoc object_detection/protos/*.proto --python_out=. export PYTHONPATH=$PYTHONPATH:pwd:pwd/slim
运行如下代码,若输出OK则配置成功
python object_detection/builders/model_builder_test.py
Ran 18 tests in 0.079s
OK
注意每次export PYTHONPATH=$PYTHONPATH:pwd:pwd/slim
再使用API
或者: gedit ~/.bashrc 添加如下两行,根据具体路径修改
export PYTHONPATH="/home/andy/anaconda3/lib/python3.7/site-packages/tensorflow/models/research:$PYTHONPATH" export PYTHONPATH="/home/andy/anaconda3/lib/python3.7/site-packages/tensorflow/models/research/slim:$PYTHONPATH"
#使修改立即生效
source ~/.bashrc
若遇到如下问题,很可能是tensorflow2.0+版本,可以回退到低版本
AttributeError: module 'tensorflow' has no attribute 'contrib'
首先需要安装labelimg标注工具,最近发现了一个挺好用的标注工具“精灵标注助手”,可以使用该工具替代labelimg.若想使用labelimg标注,可以使用如下方式安装:
进入https://github.com/tzutalin/labelImg,下载工程文件 解压并在该工程路径下执行如下命令:
sudo apt-get install pyqt5-dev-tools
pip install -r requirements/requirements-linux-python3.txt
make qt5py3
python labelImg.py
安装完成后,使用VOC格式,开始标注数据
下载官方代码:https://github.com/tensorflow/models
1)预处理
批量转换数据大小:
代码连接:https://github.com/italojs/resize_dataset_pascalvoc 进入代码路径 cd /home/light/abaowork/tflite/data/data_pre/resize_dataset_pascalvoc-master/
python main.py -p /home/andy/data/all --output ./train --new_x 480 --new_y 640 --save_box_images 1
2)将.xml文件转换成csv文件
1.划分数据集 2.运行xml_to_csv.py(自主实现),将在images文件夹下产生两个.csv文件,分别为train_labels.csv和test_labels.csv
3)转换tfrecord
修改generate_tfrecord.py中的类别名称,根据具体自己训练集修改名称,或增减类别量 def class_text_to_int(row_label): if row_label == 'nine': return 1 elif row_label == 'ten': return 2
执行代码
export PYTHONPATH=$PYTHONPATH:pwd
:pwd
/slim python object_detection/generate_tfrecord.py --csv_input=object_detection/data_card/train_labels.csv --image_dir=object_detection/data_card/train --output_path=object_detection/data_card/train.record
python object_detection/generate_tfrecord.py --csv_input=object_detection/data_card/test_labels.csv --image_dir=object_detection/data_card/test --output_path=object_detection/data_card/test.record
下载地址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 模型地址在3.3的文件中配置
ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config
python object_detection/model_main.py --logtostderr --model_dir=object_detection/training/ --pipeline_config_path=/home/light/models/research/object_detection/training/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config
查看训练情况:
tensorboard --logdir=./ watch -n 1 nvidia-smi
python object_detection/export_tflite_ssd_graph.py \ --pipeline_config_path=/home/andy/models/research/object_detection/training/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config \ --trained_checkpoint_prefix=/home/andy/models/research/object_detection/training/model.ckpt-287 \ --output_directory=object_detection/inference_graph_tflite \ --add_postprocessing_op=true
1)下载tensorflow源代码:https://github.com/tensorflow/tensorflow
2)进入源码路径下:/home/light/abaowork/tflite/models/research/tensorflow-master/tensorflow-master 安装bazel https://docs.bazel.build/versions/master/install-ubuntu.html
编译 bazel
build tensorflow/lite/toco:toco
3)转换为tflite格式
./bazel-bin/tensorflow/lite/toco/toco \ --input_file=/home/andy/models/research/object_detection/inference_graph_tflite/tflite_graph.pb \ --output_file=/home/andy/models/research/object_detection/inference_graph_tflite/detect.tflite \ --input_shapes=1,300,300,3 \ --input_arrays=normalized_input_image_tensor \ --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \ --inference_type=QUANTIZED_UINT8 \ --mean_values=128 \ --std_values=128 \ --change_concat_input_ranges=false \ --allow_custom_ops
启动android-studio /data/下载文件/android-studio/bin/studio.sh 下载工程示例https://github.com/tensorflow/examples 屏蔽下载模型程序,替换DetectorActivity.java 模型和标签
手机连接电脑,打开开发者模式,启动usb
运行,最后会在手机端生成app,根据不同的数据集,可以实现不同的检测任务
1.关于tensorflow lite 的介绍:https://tensorflow.google.cn/lite/
2.安卓工程项目demo:https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/android/README.md
3.转换图像大小和对应标注:https://github.com/italojs/resize_dataset_pascalvoc
4.CSDN关于tensorflow object detection API的介绍https://blog.csdn.net/csdn_6105/article/details/82933628
5.官方教程【重点参考】需要外网:https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193
6.配置object detection API环境:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
欢迎加我好友,交流学习:微信(abaofight)