计算机视觉 | MMDetection代码实战课

本教程一共包括如下流程:
1. 数据集准备和可视化
2.自定义配置文件
3.训练前可视化验证
4.模型训练
5.模型测试和推理
6.可视化分析

下面我们将以 MMDetection 团队自研的 RTMDet 算法为例,结合一个简单的 cat 数据集来描述整个训练推理可视化过程。

0、环境检测及安装

# 安装 mmengine 和 mmcv 依赖
# 为了防止后续版本变更导致的代码无法运行,我们暂时锁死版本
!pwd
%pip install -U "openmim==0.3.7"
!mim install "mmengine==0.7.1"
!mim install "mmcv==2.0.0"

# Install mmdetection
!rm -rf mmdetection
# 为了防止后续更新导致的可能无法运行,特意新建了 tutorials 分支
!git clone -b tutorials https://github.com/open-mmlab/mmdetection.git
%cd mmdetection

%pip install -e .
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env

import mmdet

# 环境信息收集和打印
def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMDetection'] = f'{mmdet.__version__}+{get_git_hash()[:7]}'
    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print(f'{name}: {val}')

1、数据集准备及可视化

我们提供了一个简单的 cat 猫数据集,该数据集来自社区用户,总共包括 144 张图片,并且已经提前划分为了训练集和测试集。

# 数据集可视化

import os
import matplotlib.pyplot as plt
from PIL import Image

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))

image_paths= [filename for filename in os.listdir('cat_dataset/images')][:8]

for i,filename in enumerate(image_paths):
    name = os.path.splitext(filename)[0]

    image = Image.open('cat_dataset/images/'+filename).convert("RGB")
  
    plt.subplot(2, 4, i+1)
    plt.imshow(image)
    plt.title(f"{filename}")
    plt.xticks([])
    plt.yticks([])

plt.tight_layout()

2、自定义配置文件

本教程采用 RTMDet 进行演示,在开始自定义配置文件前,先来了解下 RTMDet 算法。

RTMDet 是一个高性能低延时的检测算法,目前已经实现了目标检测、实例分割和旋转框检测任务。其简要描述为:**为了获得更高效的模型架构,MMDetection 探索了一种具有骨干和 Neck 兼容容量的架构,由一个基本的构建块构成,其中包含大核深度卷积。MMDetection 进一步在动态标签分配中计算匹配成本时引入软标签,以提高准确性。结合更好的训练技巧,得到的目标检测器名为 RTMDet,在 NVIDIA 3090 GPU 上以超过 300 FPS 的速度实现了 52.8% 的 COCO AP,优于当前主流的工业检测器。RTMDet 在小/中/大/特大型模型尺寸中实现了最佳的参数-准确度权衡,适用于各种应用场景,并在实时实例分割和旋转对象检测方面取得了新的最先进性能。

3、训练前可视化验证

我们可以采用 mmdet 提供的 tools/analysis_tools/browse_dataset.py 脚本来对训练前的 dataloader 输出进行可视化,确保数据部分没有问题。
考虑到我们仅仅想可视化前几张图片,因此下面基于 browse_dataset.py 实现一个简单版本即可。

4、模型训练

python tools/train.py rtmdet_tiny_1xb12-40e_cat.py

5、模型测试和推理

python tools/test.py rtmdet_tiny_1xb12-40e_cat.py work_dirs/rtmdet_tiny_1xb12-40e_cat/best_coco/bbox_mAP_epoch_30.pth

6、可视化分析

可视化分析包括特征图可视化以及类似 Grad CAM 等可视化分析手段。不过由于 MMDetection 中还没有实现,我们可以直接采用 MMYOLO 中提供的功能和脚本。MMYOLO 是基于 MMDetection 开发,并且此案有了统一的代码组织形式,因此 MMDetection 配置可以直接在 MMYOLO 中使用,无需更改配置。

你可能感兴趣的:(计算机视觉,人工智能,深度学习)