MMDetection(四):在自己的数据集上训练模型

MMDetection(四):在自己的数据集上训练模型

  • 1. 数据集准备
  • 2. 修改参数
    • 2.1 修改数据集的相关参数
    • 2.2 修改训练相关参数
  • 3. 训练网络
    • 3.1 单GPU训练
    • 3.2 多GPU训练
  • 4. 使用训练结果进行测试并可视化
    • 4.1 验证集图片测试
    • 4.2 训练日志可视化
  • 5. 计算模型复杂度
  • 6. 计算推理速度

1. 数据集准备

本文所使用的数据集为口腔数据集,共有5个类别,分别为:mouth,teeth,tongue,uvula,和oropharynx。
在mmdetection/data文件夹下新建文件夹,将数据集放到这里。

需要将数据集格式转化为VOC或COCO数据集的格式,因为configs里提供的依赖coco的模型较多,因此,我们建议将数据集转化成coco格式。

转化方法参考博客:
Labelme标注的json数据转化为coco格式的数据

2. 修改参数

  • 说明:(1)数据类别有5类:mouth,teeth,tongue,uvula,和oropharynx。
    (2)使用的训练模型是configs/ssd/ssd300_coco.py

2.1 修改数据集的相关参数

官方建议直接修改coco数据集定义文件

  • 修改数据集路径文件: configs/base/datasets/coco_detection.py
    (1)修改data_root为自己数据集的路径;
    (2)修改data字典中train、val、teat相关路径。

  • 修改模型配置文件:configs/base/models/ssd300.py
    (1)修改bbox_head字典中的num_classes为3。

  • 修改coco数据集定义文件:mmdet/datasets/coco.py
    (1)将CLASSES那里的参数修改为:CLASSES = (‘mouth’, ‘teeth’, ‘tongue’, ‘uvula’, ‘oropharynx’)
    (2)将PALETTE参数随意选5个留下即可,例如:PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228)],这个参数用来指定每个类别框的显示颜色。

  • 修改class_name:mmdet/core/evaluation/class_names.py
    (1)定位到coco_classes函数,修改return中的参数为:‘mouth’, ‘teeth’, ‘tongue’, ‘uvula’, ‘oropharynx’

2.2 修改训练相关参数

  • 修改学习率、优化器相关参数:configs/base/schedules/schedule_1x.py
    (1)主要修改学习率lr的值,一般按照线性计算,官方8张GPU设置为0.02,则4张为0.01,2张为0.005
  • 修改其他参数:configs/base/default_runtime.py,例如:
# 保存checkpoints的间隔 默认每次都保存
checkpoint_config = dict(interval=1)
# yapf:disable
# # yapf:disable 打印log的间隔(每个epoch中) 默认迭代50次打印一次(datasets的大小除以batchsize)
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None    # 加载参数
resume_from = None  # 断点续训 重新加载已训练好的checkpoints 包含epoch等信息 会覆盖load_form
workflow = [('train', 1)]  # 工作流

# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)

3. 训练网络

3.1 单GPU训练

python tools/train.py configs/ssd/ssd300_coco.py  --gpus 1 

3.2 多GPU训练

CUDA_VISIBLE_DEVICES=2,3 bash tools/dist_train.sh configs/ssd/ssd300_coco.py  2

训练完之后work_dirs文件夹中会保存下训练过程中的log日志文件、每个epoch的pth文件(因为在default_runtime.py设置了checkpoint_config = dict(interval=1)),这个文件将会用于后面的test测试。
MMDetection(四):在自己的数据集上训练模型_第1张图片

4. 使用训练结果进行测试并可视化

4.1 验证集图片测试

pythonpython tools/test.py configs/ssd/ssd300_coco.py work_dirs/ssd300_coco/latest.pth --eval bbox --out work_dirs/ssd300_coco/result.pkl  --show
  • 参数说明:
    config:模型训练的配置文件
    checkpoint:训练结果的权重文件
    –eval:验证指标,一般使用bbox
    –out:测试结果文件保存的路径及名称
    –show:展示每一张验证集图片的测试结果
  • 显示结果如下:
    MMDetection(四):在自己的数据集上训练模型_第2张图片MMDetection(四):在自己的数据集上训练模型_第3张图片并在work_dirs/ssd300_coco文件夹内生成result.pkl 文件

4.2 训练日志可视化

python tools/analysis_tools/analyze_logs.py plot_curve work_dirs/ssd300_coco/20220823_210651.log.json --keys loss_cls loss_bbox

将训练结果20220312_094204.log.json中的参数acc、loss_cls和loss_bbox进行可视化,结果如下
MMDetection(四):在自己的数据集上训练模型_第4张图片

5. 计算模型复杂度

  • 计算复杂度:FLOPs(注意s是小写)
    floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量,和软硬件的配置没有关系,可以公平地用来衡量算法/模型的复杂度。
    计算公式:
    在这里插入图片描述
  • FLOPS(floating point operations per second)
    意指每秒浮点运算次数,理解为计算速度,是一个衡量硬件性能的指标。

在MMDetection中可以使用tools/analysis_tools/get_flops.py命令来获取模型的复杂度:

python tools/analysis_tools/get_flops.py work_dirs/ssd300_coco/ssd300_coco.py 

输出结果显示:
MMDetection(四):在自己的数据集上训练模型_第5张图片

6. 计算推理速度

使用tools/analysis_tools/benchmark.py函数来输出模型的推理速度,注意,mmdet只支持分布式版本,并且它测试的是2000张图片(前500忽略)的平均值。每50张图片显示一次结果

python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 tools/analysis_tools/benchmark.py work_dirs/ssd300_coco/ssd300_coco.py work_dirs/ssd300_coco/latest.pth --launcher pytorch

运行结果:
MMDetection(四):在自己的数据集上训练模型_第6张图片

你可能感兴趣的:(代码,python,深度学习,人工智能,mmdetection)