安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记

 

今天也算我首次发实现过程博文,写的不好请谅解,如果对您有帮助,麻烦点个赞噢

说一下我的实现过程,我是Object Detection API训练模型,

Tensorflow版本 1.14

操作系统:windows 7  64位+ ubantu 18.04

内存:10G

训练用GPU:1080TI 11G (注:AMD GPU暂时不太清楚,如果自己nvdia gpu低于750TI,显存低于2G,可以用CPU或GPU租借平台,GPU速度大概是CPU10倍)

Python: 3.6

训练框架:ssdlite_mobilenet_v2

 

一、Windows环境配置与安装:

  1. 度娘安装python3.6、pycharm
  2. 度娘安装标准Nvidia套件 CUDA+Cudnn(官网查自己对应版本)
  3. pip install  tensorflow==xx(对应版本号) (CPU版) 或者 pip install  tensorflow-gpu==(xx)(GPU版)   #命令行末尾加
    -i https://pypi.douban.com/simple 换国内豆瓣源,速度快
  4. cmd 下键入python
    import tensorflow as tfl
    hello = tfl.constant('Hello, TensorFlow!')
    sess = tfl.Session()
    print(sess.run(hello))
    
    不报错,出现Hello, TensorFlow!即完成TensorFlow安装
  5. 下载 TensorFlow官方模型库
  6. 安装protuf  (非常容易出错,我这里选择的是win版 3.4)         将bin文件夹中的【protoc.exe】放到C:\Windows 并cmd进入models\research\目录下
    protoc object_detection/protos/*.proto --python_out=.

    不报错即完成,目录下的py文件是29个。(好像只是为了编译py文件,之后我把我的models上传上去,供大家下载)

  7. PYTHONPATH 环境变量设置
    在 ‘此电脑’-‘属性’- ‘高级系统设置’ -‘环境变量’-‘系统变量’ 中新建名为‘PYTHONPATH’的变量,将

    models/research/ 及 models/research/slim 两个文件夹的完整目录添加,分号隔开安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第1张图片

    接下来可以测试API,在 models/research/ 文件夹下运行命令行:

    python object_detection/builders/model_builder_test.py

    不报错说明运行成功。

  8. 测试自带案例 如果没问题就可以训练自己的模型了

二、模型训练

严格遵照该过程进行!!!

里面可能需要修改自己文件路径,进行相应改动即可,如果是用别人的数据集,一定要注意对应文件的xml信息及xml文件名是否对应

images1
0.png
C:\Users\White\Desktop\images1\0.png

此外,随着tf不断更新,旧版本的train.py位置一直变化,但都没删,可以在object_detection文件夹里找找,我的是在research\object_detection\legacy里面发现的,网上还有用新版本的model_main.py,各人 感觉是差不多,但糟心的是新版本没有个提醒,训练开始后啥都没有(进度恐惧感你懂得),大神可以自己写个进度信息。

如果显存低,或内存不足,建议调整选用模型的.config  将batch_size调整到1-6, 我是3000张样本,2个标签类11G1080TI,batch_size调为了6,显存使用是9G左右,普通2-4G建议调整到1-3,否则训练一定程度后,容易报显存不足(别问我是怎么知道的。。。),训练步数大概从40000-200000之间。

贴几张成果图

安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第2张图片安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第3张图片安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第4张图片安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第5张图片

训练的很到位,识别率棒棒的。

训练识别数最好>=2,如果只训练一类,会导致在没有目标时,出现误检,这也相当于设置负样本

三、树莓派移植

训练模型平台是在Windows,模型转换是在Linux。

在树莓派上安TensorFlow可以正常运行,但速度有点感人,大概40-50s 一张,完全失去了检测意义好吧!!!于是就往tflite上转型。tflite其实是之前谷歌为安卓(亲儿子)量身定制的,只是树莓派挺火,顺便支持一下,最初源码都是外人写的。。。tf只是拿来用用。

首先说个结论:TensorFlow Object-Detection API训练的模型是可以转tflite模型的。这也是本文意义所在

TensorFlow从2018年开始,就以肉眼可见的速度更新着,这一点很棒,与时俱进。但问题也源源不断的出现,就是这一版本的命令,在下一版本改了,或是直接删了(因吹斯听),这就导致给本来不熟悉TensorFlow的人一个k.o。另外,TensorFlow的文档混乱,官方的野生的都有(2.0之后好像好点)。简直就是一入tf,从此头发是路人。。。

吐完槽后,活还是得照干,文档还是得照读,代码还是得照撸。。。

首先就是先在树莓派上搭建TensorFlow lite,这个基本没啥坑,按这个做就行,如果报找不到no module :tflite_runtime,你在导入这个包前加入

import sys
sys.path.append("/home/pi/.local/lib/python3.7/site-packages")

也就是你Python的包位置

最坑的来了,模型转换,干了两三天,终于从坑里爬出来了!!!

我们现在是已经训练出来TensorFlow模型,将TensorFlow的.pb模型转化为tensorflowlite的.tflite模型,而在TensorFlow的改版中,这一块的改动可谓异常巨大。。。现在版本好像又换成命令行了。

网上转模型,一众的bazel 编译,bazel在Windows安装就挺麻烦,编译过程更麻烦,还容易报各种错误,于是我在这一块转到了ubantu下,倒是可以按照这个过程出来,但还是太麻烦了。

其实按我上面训练的模型,根本不用编译!!!不用bazel!!!

       下载 tensorflow源码

在前面,我们训练的结束后的TensorFlow模型会有frozen_inference_graph.pb文件生成,这也是编译下面这行命令想要生成这个文件的工具

bazel build tensorflow/python/tools:freeze_graph

之后在object_detection文件夹中,可找到export_tflite_ssd_graph.py文件

python export_tflite_ssd_graph.py --input_type image_tensor --pipeline_co
nfig_path training/ssd_mobilenet_v1_XXX.config --trained_checkpoint_prefix training/model.ckpt-XXXX --output_directory detection   

即可在输出文件夹中生成tflite_graph.pb和tflite_graph.pbtxt两个文件

下面这行编译命令也是没有必要的

bazel build tensorflow/lite/toco:toco

可以通过(原码和Python 下site-packages包里均有)

python tensorflow/tensorflow/lite/python/tflite_convert.py /
--graph_def_file=xxx/tflite_graph.pb /
--output_format=TFLITE /
--inference_type=QUANTIZED_UINT8 /
--inference_input_type=QUANTIZED_UINT8/
--std_dev_values=128 /
--mean_values=128 /
--input_arrays=normalized_input_image_tensor /
--input_shapes=1,300,300,3 /
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' /
--allow_custom_ops /
--output_file=xxx/ssd.tflite

或是pb_to_tflite.py程序(如下)替代

import tensorflow as tf

in_path = "tflite_graph.pb"

# 模型输入节点
input_tensor_name = ["normalized_input_image_tensor"]
input_tensor_shape = {"normalized_input_image_tensor":[1,300,300,3]}
# 模型输出节点
classes_tensor_name = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']

converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path,
                                            input_tensor_name, classes_tensor_name,
                                            input_tensor_shape)

converter.allow_custom_ops=True
converter.post_training_quantize = True
tflite_model = converter.convert()

open("SSD.tflite", "wb").write(tflite_model)

tf.lite.TFLiteConverter在不同版本下对应代码不同https://www.freesion.com/article/6385314163/

安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第6张图片

这里需要注意的是,如果你是按object detect api训练出的模型,可以直接按我的来,只需将input_tensor_shape换成你模型中的值就OK。如果不是,你还需要确认.pb模型输入输出节点array名称和相关矩阵参数

 其他的都不需要,更不需要编译toco,最后通过pb_to_tflite.py生成tflite文件,在踩了无数坑后,大功告成!!!

在模型应用方面推荐一篇博文,没看之前,我一度感觉我转出来的模型是废物。。。这篇博文给我很大帮助

https://blog.csdn.net/qq_39567427/article/details/104057234

我用现在的官网detect_picamera.py和label_image.py等一系列样例,均无法使用,不报错,就是检测不出来

这篇博文里面给了一个非官网的github连接https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi

里面的TFLite_detection_image.py和训练模型完美适配,识别结果和PC端基本一致,且速度比原来快很多,瞬间出结果。

贴树莓派运行图

安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第7张图片安全帽深度学习训练:从TensorFlow Object-Detection API到树莓派 tflite移动端踩坑记_第8张图片

就问你树莓派上600ms左右识别一张图香不香

之前转换出错的原因很多,其中最主要的感觉是TensorFlow版本更新太快,导致文档部件混乱,官方和非官方的都特别多,且有的文档一个版本一个位置,这就造成很多教程无法复制性。哈哈,选择了TensorFlow,大家就享受踩坑之旅吧。

参考链接

https://blog.csdn.net/dy_guox/article/details/79111949

https://www.pythonf.cn/read/3900

https://www.cnblogs.com/White-xzx/p/9503203.html

https://blog.csdn.net/SpiritYzw/article/details/105629397?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-10.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-10.nonecase

https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi

https://github.com/tensorflow/tensorflow

你可能感兴趣的:(深度学习,tensorflow)