使用官方版RT-DETR训练自己的数据集(Pytorch)

一、环境配置

本实验环境配置如下表。

Operating system

 Linux-5.15.0-86-generic-x86_64-with-glibc2.31

CPU

Intel(R) Xeon(R) Gold 5318Y CPU @ 2.10GHz

GPU

NVIDIA RTX A4000

Python version

3.11.5

Deep learning framework

PyTorch2.0.1

CUDA version

11.8

Memory size

16GB

第三方库安装命令:

pypackage需替换为具体的第三方库名称。

pip install pypackage -i https://pypi.tuna.tsinghua.edu.cn/simple

二、源码、数据集准备

2.1 源码获取

RT-DETR 的源码可在GitHub中搜索RT-DETR下载,或通过我的分享链接进行下载:

链接:https://pan.baidu.com/s/1Y4OAcRZue3xv6XpApRDtrw 
提取码:detr 
--来自百度网盘超级会员V4的分享

2.2 数据集制备

RT-DETR训练所需的数据集为COCO格式的数据集,如果你的数据集为YOLO或VOC格式的,则需要进行格式转换。

coco格式的数据集样例如下所示:

其中annotation文件夹下的instances_train2017.json文件存放着训练集所有图像的标注信息,instances_val2017.json文件存放着验证集所有图像的标注信息。

annotations:存储标注信息的文件夹
        instances_train2017.json:训练集的标注文件
        instances_val2017.json:验证集的标注文件
train2017:训练集的图片文件夹
        000000000001.jpg
        000000000002.jpg
        ...
val2017:验证集的图片文件夹
        000000000042.jpg
        000000000071.jpg
        ...

关于yolo格式的数据集转coco格式的,可参考该文章:http://t.csdnimg.cn/eZiTd

三、代码文件的修改

3.1 coco_detection.yml

文件地址:./RT-DETR/RT-DETR-main/rtdetr_pytorch/configs/dataset/coco_detection.yml

修改该yml文件中的img_folder 和 ann_file ,使其分别为你的数据集图像所在目录和标注文件地址

task: detection

num_classes: 80
remap_mscoco_category: False

train_dataloader: 
  type: DataLoader
  dataset: 
    type: CocoDetection
#############################################
    img_folder: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/train/
    ann_file: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/annotations/train.json
#############################################
    transforms:
      type: Compose
      ops: ~
  shuffle: True
  batch_size: 8
  num_workers: 4
  drop_last: True 


val_dataloader:
  type: DataLoader
  dataset: 
    type: CocoDetection
#############################################
    img_folder: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/val/
    ann_file: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/annotations/val.json
#############################################
    transforms:
      type: Compose
      ops: ~ 

  shuffle: False
  batch_size: 8
  num_workers: 4
  drop_last: False

3.2 train.py 

代码文件地址:RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/train.py

config参数需根据自己的rtdetr_r18vd_6x_coco.yml文件所在地址设置。

"""by lyuwenyu
"""

import os 
import sys 
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
import argparse

import src.misc.dist as dist 
from src.core import YAMLConfig 
from src.solver import TASKS


def main(args, ) -> None:
    '''main
    '''
    dist.init_distributed()

    assert not all([args.tuning, args.resume]), \
        'Only support from_scrach or resume or tuning at one time'

    cfg = YAMLConfig(
        args.config,
        resume=args.resume, 
        use_amp=args.amp,
        tuning=args.tuning
    )

    solver = TASKS[cfg.yaml_cfg['task']](cfg)
    
    if args.test_only:
        solver.val()
    else:
        solver.fit()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str,default = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/configs/rtdetr/rtdetr_r18vd_6x_coco.yml" )
    parser.add_argument('--resume', '-r', type=str,default =False )
    parser.add_argument('--tuning', '-t', type=str,default = False )
    parser.add_argument('--test-only', action='store_true', default=False,)
    parser.add_argument('--amp', action='store_true', default=False,)

    args = parser.parse_args()

    main(args)

3.3 权重文件

初次训练时会自动下载权重文件,如因报错无法自动下载,可手动下载并将权重文件放入RT-DETR/RT-DETR-main/rtdetr_pytorch/ 文件夹即可。

你可以于本人分享的网盘链接中下载权重文件。

使用官方版RT-DETR训练自己的数据集(Pytorch)_第1张图片

四、执行训练

直接运行train.py或在命令行中输入

python train.py

即可开始执行训练。

训练日志如下图所示。

使用官方版RT-DETR训练自己的数据集(Pytorch)_第2张图片

训练完毕后可进行模型导出和推理,参考:http://t.csdnimg.cn/G59hR

你可能感兴趣的:(pytorch,人工智能,python,深度学习,目标检测)