【MMDetection】v2.22.0入门:训练自己的数据集

文章目录

    • 一、MMDetection安装并测试
      • 安装步骤
    • 二、数据集准备
    • 三、训练前的准备(修改参数)
      • 事先说明
      • 修改数据集相关参数
      • 修改训练相关参数
    • 四、开始训练
      • 单GPU训练
      • 指定多GPU训练
    • 五、使用训练结果进行测试并可视化
      • 验证集图片测试
      • 训练日志可视化
      • DetVisGUI可视化
      • 计算模型复杂度
    • References

一、MMDetection安装并测试

MMDetection是MMLab家族的一员,是由香港中文大学和商汤科技共同推出的,以一个统一的架构支撑了15个大方向的研究领域,实现了1300+的算法复现工作。MMDetection依赖Pytorch和MMCV(mmcv/mmcv-full),因此安装之前需要先安装这两个库,具体安装步骤可参考这篇博客:MMDetection框架入门教程(一):Anaconda3下的安装教程(mmdet+mmdet3d)。

安装步骤

  1. Anaconda虚拟环境搭建
    1. conda create -n mmdet python=3.8
    2. conda activate mmdet
  2. Pytorch安装
    1. nvidia-smi确定服务器中CUDA的版本
    2. conda install pytorch==1.9.1 torchvision torchaudio cudatoolkit=10.2 -c pytorch安装对应版本的torch
    3. torch.cuda.is_available验证torch是否安装成功
  3. 安装MMCV
  4. 安装MMDetection
  5. 使用Demo验证是否安装成功
    1. mmdetection/新建checkpoints文件夹,下载faster-rcnn的预训练模型权重到该文件夹下
    2. mmdetection/新建test_demo.py文件,输入以下代码,然后运行
# 测试mmdet、mmcv是否安装成功
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import torch

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = init_detector(config_file, checkpoint_file, device=device)
img = 'demo/demo.jpg'
result = inference_detector(model, img)
print(len(result))
show_result_pyplot(model=model, img=img, result=result, score_thr=0.9)

若安装成功,则运行结果如下图所示

【MMDetection】v2.22.0入门:训练自己的数据集_第1张图片

 

二、数据集准备

本文所使用的数据集为三种水果数据集,下载链接为:https://download.csdn.net/download/weixin_43799388/84425688。在mmdetection/新建data文件夹,将数据集解压后放到这里。

数据集文档结构如下:

|——-Fruit
	|---Annotations
		|---001.xml
		|---002.xml
		... ...
		|---340.xml
	|---images
		|---001.jpg
		|---002.jpg
		... ...
		|---340.jpg

MMDetection一共支持三种形式应用新数据集:

  1. 将数据集重新组织为 COCO 格式
  2. 将数据集重新组织为一个中间格式
  3. 实现一个新的数据集

官方建议使用前面两种方法,因为它们通常来说比第三种方法要简单。

该数据集的标注形式为xml格式,其中GT框的坐标信息是以左上角坐标(xmin, ymin)和右下角坐标(xmax, ymax)形式来标注的,这里我们采用MMDetection建议的第一种方法,将数据集重新组织为 COCO 格式,具体转换步骤可以参考我写的另一篇博客:VOC(xml)标注格式转换为YOLOv5(txt)和COCO2017(json)格式。

转换结果的文档结构如下所示,我们真正用到的是annotations、train和val这三个文件夹。

|——-Fruit
	|---annotations
		|---Fruit_train.json
		|---Fruit_val.json
	|---Annotations
		|---001.xml
		|---002.xml
		... ...
		|---340.xml
	|---images
		|---001.jpg
		|---002.jpg
		... ...
		|---340.jpg
    |---ImageSets
		|---train.txt  # 存放训练集图片名称
		|---val.txt  # 存放验证集图片名称
	|---train
		|---001.jpg
		|---002.jpg
		... ...
	|---val
		|---046.jpg
		|---049.jpg
		... ...

 

三、训练前的准备(修改参数)

事先说明

  • 数据集类别一共有3个:apple/banana/grape
  • 使用的训练模型是configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py

 

修改数据集相关参数

虽然我们已经将xml格式标注的数据集转换成了MMDetection使用的COCO数据集格式,但还需要修改一些配置参数(官方建议直接修改coco数据集定义文件):

  1. 修改模型配置文件:configs/_base_/models/faster_rcnn_r50_fpn.py
    1. 定位到roi_head字典出,修改bbox_head字典中的num_classes为3
  2. 修改coco数据集定义文件:mmdet/datasets/coco.py
    1. CLASSES那里的参数修改为:CLASSES = ('apple', 'banana', 'grape')
    2. PALETTE参数随意选三个留下即可,例如:PALETTE = [(220, 20, 60), (119, 11, 32), (191, 162, 208)],这个参数用来指定每个类别框的显示颜色
  3. 修改class_name:mmdet/core/evaluation/class_names.py
    1. 定位到coco_classes函数,修改return中的参数为:'apple', 'banana', 'grape'
  4. mmdetection目录下新建test_work_dirs文件夹

 

修改训练相关参数

此外,在configs/_base_/default_runtime.py文件中可以修改训练时的其他参数,本次训练的default_runtime.py代码如下:

# 保存checkpoints的间隔 默认每次都保存
checkpoint_config = dict(interval=4)
# yapf:disable 打印log的间隔(每个epoch中) 默认迭代50次打印一次(datasets的大小除以batchsize)
log_config = dict(
    interval=5,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None  # 加载参数
# 断点续训 重新加载已训练好的checkpoints 包含epoch等信息 会覆盖load_form
resume_from = None
# 工作流
workflow = [('train', 1)]
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

 

四、开始训练

单GPU训练

python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --gpus 1 --work-dir test_work_dirs

 

指定多GPU训练

CUDA_VISIBLE_DEVICES=2,3指定使用GPU-3和GPU-4,放到python命令之前,同时要设置--gpus 2

CUDA_VISIBLE_DEVICES=2,3 python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --gpus 2 --work-dir xiaolong_dir0

训练过程界面如下:

训练完之后test_work_dirs文件夹中会保存下训练过程中的log日志文件、每4个epoch的pth文件(因为在default_runtime.py设置了checkpoint_config = dict(interval=4)),这个文件将会用于后面的test测试。

 

五、使用训练结果进行测试并可视化

验证集图片测试

python tools/test.py test_work_dirs/faster_rcnn_r50_fpn_1x_coco.py test_work_dirs/epoch_12.pth --eval bbox --out test_work_dirs/result12.pkl --show 

传递参数说明:

  • config:模型训练的配置文件
  • checkpoint:训练结果的权重文件
  • --eval:验证指标,一般使用bbox
  • --out:测试结果文件保存的路径及名称
  • --show:展示每一张验证集图片的测试结果

测试结果如下:

【MMDetection】v2.22.0入门:训练自己的数据集_第2张图片

 

训练日志可视化

python tools/analysis_tools/analyze_logs.py plot_curve test_work_dirs/20220312_094204.log.json --keys acc loss_cls loss_bbox

将训练结果20220312_094204.log.json中的参数acc、loss_clsloss_bbox进行可视化,结果如下(由于数据集太小,并且只有3个class,因此训练很快就收敛了):

【MMDetection】v2.22.0入门:训练自己的数据集_第3张图片

 

DetVisGUI可视化

DetVisGUI工具是一个用于可视化MMDetection测试结果的轻量级GUI,它可以动态显示不同阈值的检测结果,便于验证检测结果和GT框的差异。

  • 在使用DetVisGUI工具之前,需要利用python tools/test.py命令,保存测试的结果文件(pkl格式或者json格式)

  • 之后从GitHub上下载DetVisGUI源码,放在mmdetection/DetVisGUI/文件夹中,运行DetDetVisGUI的命令格式是:

python DetVisGUI/DetVisGUI.py ${CONFIG_FILE} [--det_file ${RESULT_FILE}] [--stage ${STAGE}] [--output ${SAVE_DIRECTORY}]

传递参数说明:

  • config:模型训练的配置文件
  • --det_file:测试结果文件(pkl格式或json格式)
  • --stage:测试结果文件的三种stage(train/val/test),默认为val
  • --output:测试图片的保存路径,默认为output

运行如下命令:

python DetVisGUI/DetVisGUI.py test_work_dirs/faster_rcnn_r50_fpn_1x_coco.py --det_file test_work_dirs/result12.pkl --output test_work_dirs/val_result

运行结果如下图所示,可以调节IOU阈值、置信度阈值、是否显示GT框及文本信息等,来对比测试结果。

【MMDetection】v2.22.0入门:训练自己的数据集_第4张图片

点击左上角的Save All Results按钮,即可将全部验证集图片的测试结果保存到指定路径中。

 

计算模型复杂度

首先来科普一下FLOPs和FLOPS的区别:

  • 计算复杂度:FLOPs(注意s是小写)
    • floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量,和软硬件的配置没有关系,可以公平地用来衡量算法/模型的复杂度
    • 计算公式: F L O P s = C o u t ∗ H o u t ∗ W o u t ∗ C i n ∗ k ∗ k FLOPs=C_{out}*H_{out}*W_{out}*C_{in}*k*k FLOPs=CoutHoutWoutCinkk
  • FLOPS(floating point operations per second)
    • 意指每秒浮点运算次数,理解为计算速度,是一个衡量硬件性能的指标

在MMDetection中可以使用tools/analysis_tools/get_flops.py命令来获取模型的复杂度:

python tools/analysis_tools/get_flops.py test_work_dirs/faster_rcnn_r50_fpn_1x_coco.py

运行结果如下所示,可以看到get_flops.py函数会打印出模型每一层的FLOPs和参数量,以及总的FLOPs和参数量。
但要注意的是,这些数据仅供参考,不一定准确,输出结果的最后一行也提醒我们,不建议把这个输出结果放到论文中。

【MMDetection】v2.22.0入门:训练自己的数据集_第5张图片

 

References

MMDetection框架入门教程(一):Anaconda3下的安装教程(mmdet+mmdet3d)

官方教程:MMDETECTION’S DOCUMENTATION!

VOC(xml)标注格式转换为YOLOv5(txt)和COCO2017(json)格式

【mmdetection】使用自定义的coco格式数据集进行训练及测试

mmdetection可视化工具-DetVisGUI

你可能感兴趣的:(目标检测,mmdetection,深度学习,目标检测)