使用pytorch进行深度学习网络模型训练,实现车型识别


简介

通过深度学习技术搭建残差网络,使用 CompsCars数据集进行车型识别模型的训练,并将训练好的模型移植到了Android端,实现了通过手机扫一扫的方式进行汽车车型识别的功能。

项目涉及到的技术点较多,需要开发者有一定的技术功底。如:python语言的使用、深度学习框架pytorch的使用、爬虫脚本的理解、Java语言的使用、Android平台架构的理解等等。

虽然属于跨语言开发,但是要求并不高,只要达到入门级别即可看懂本项目,并可以尝试一些定制化的改造。毕竟框架已经搭建好了,只需要修改数据源、重新训练出模型,就可以实现一款新的应用啦。

模型训练精度

以下是使用Resnet-34进行400次车型识别训练的 train-validation图表。

使用pytorch进行深度学习网络模型训练,实现车型识别_第1张图片

以下是使用Resnet-34进行400次车型识别训练 Top-1的错误率。

使用pytorch进行深度学习网络模型训练,实现车型识别_第2张图片


以下是使用Resnet-34进行400次车型识别训练 Top-5的错误率。

扫一扫识别功能

以下是移植到android平台后进行识别的结果展示图。

使用pytorch进行深度学习网络模型训练,实现车型识别_第3张图片

使用的技术&框架

  • 开发语言:Python、Java
  • 技术框架:pytorch、resnet-34、Android平台
  • 可选借助平台:百度AI平台
  • 项目构成:模型训练项目、爬虫项目、APP开发项目

软/硬件需求

机器要求

因为涉及到机器学习模型训练,所以你应该拥有一台用来训练模型的机器,且需要搭载支持CUDA的GPU(如:GeForce、GTX、Tesla等),显存大小,自然是越大越好。

本人项目环境:

  • windows10 专业版;GeForce MAX150;独显 2G;1T硬盘

也就是说这是最低配了,你至少要和我同一配置。

开发工具

  • Pycharm:用来训练模型、pyhton爬虫、模型移植脚本
  • Android Studio:用来开发安卓APP

数据集

数据集是项目最重要的一部分,有了数据集才能开始训练

本项目使用的是

训练模型主要分为五个模块:启动器、自定义数据加载器、网络模型、学习率/损失率调整以及训练可视化。

启动器是项目的入口,通过对启动器参数的设置,可以进行很多灵活的启动方式,下图为部分启动器参数设置。

使用pytorch进行深度学习网络模型训练,实现车型识别_第4张图片

任何一个深度学习的模型训练都是离不开数据集的,根据多种多样的数据集,我们应该使用一个方式将数据集用一种通用的结构返回,方便网络模型的加载处理。

使用pytorch进行深度学习网络模型训练,实现车型识别_第5张图片

这里使用了残差网络Resnet-34,代码中还提供了Resnet-18、Resnet-50、Resnet-101以及Resnet-152。残差结构是通过一个快捷连接,极大的减少了参数数量,降低了内存使用。

以下为残差网络的基本结构和Resnet-34 部分网络结构图。

使用pytorch进行深度学习网络模型训练,实现车型识别_第6张图片

使用pytorch进行深度学习网络模型训练,实现车型识别_第7张图片
除了最开始看到的train-val图表、Top-、Top-5的error记录表以外,在训练过程中,使用进度条打印当前训练的进度、训练精度等信息。打印时机可以通过上边提到的 启动器 优雅地配置。

使用pytorch进行深度学习网络模型训练,实现车型识别_第8张图片


以下为最终的项目包架构。

 
  
  1. pytorch_train
  2. |-- data -- 存放读取训练、校验、测试数据路径的txt
  3. | |-- train.txt
  4. | |-- val.txt
  5. | |-- test.txt
  6. |-- result -- 存放最终生成训练结果的目录
  7. |-- util -- 模型移植工具
  8. |-- clr.py -- 学习率
  9. |-- dataset.py -- 自定义数据集
  10. |-- flops_benchmark.py -- 统计每秒浮点运算次数
  11. |-- logger.py -- 日志可视化
  12. |-- mobile_net.py -- 网络模型之一 mobile_net2
  13. |-- resnet.py -- 网络模型之一 Resnet系列
  14. |-- run.py -- 具体执行训练、测试方法
  15. |-- start.py -- 启动器

使用pytorch进行深度学习网络模型训练,实现车型识别_第9张图片

你可能感兴趣的:(深度学习,pytorch,python)