本实验环境配置如下表。
Operating system |
Linux-5.15.0-86-generic-x86_64-with-glibc2.31 |
CPU |
Intel(R) Xeon(R) Gold 5318Y CPU @ 2.10GHz |
GPU |
NVIDIA RTX A4000 |
Python version |
3.11.5 |
Deep learning framework |
PyTorch2.0.1 |
CUDA version |
11.8 |
Memory size |
16GB |
第三方库安装命令:
pypackage需替换为具体的第三方库名称。
pip install pypackage -i https://pypi.tuna.tsinghua.edu.cn/simple
RT-DETR 的源码可在GitHub中搜索RT-DETR下载,或通过我的分享链接进行下载:
链接:https://pan.baidu.com/s/1Y4OAcRZue3xv6XpApRDtrw
提取码:detr
--来自百度网盘超级会员V4的分享
RT-DETR训练所需的数据集为COCO格式的数据集,如果你的数据集为YOLO或VOC格式的,则需要进行格式转换。
coco格式的数据集样例如下所示:
其中annotation文件夹下的instances_train2017.json文件存放着训练集所有图像的标注信息,instances_val2017.json文件存放着验证集所有图像的标注信息。
annotations:存储标注信息的文件夹
instances_train2017.json:训练集的标注文件
instances_val2017.json:验证集的标注文件
train2017:训练集的图片文件夹
000000000001.jpg
000000000002.jpg
...
val2017:验证集的图片文件夹
000000000042.jpg
000000000071.jpg
...
关于yolo格式的数据集转coco格式的,可参考该文章:http://t.csdnimg.cn/eZiTd
文件地址:./RT-DETR/RT-DETR-main/rtdetr_pytorch/configs/dataset/coco_detection.yml
修改该yml文件中的img_folder 和 ann_file ,使其分别为你的数据集图像所在目录和标注文件地址
task: detection
num_classes: 80
remap_mscoco_category: False
train_dataloader:
type: DataLoader
dataset:
type: CocoDetection
#############################################
img_folder: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/train/
ann_file: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/annotations/train.json
#############################################
transforms:
type: Compose
ops: ~
shuffle: True
batch_size: 8
num_workers: 4
drop_last: True
val_dataloader:
type: DataLoader
dataset:
type: CocoDetection
#############################################
img_folder: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/val/
ann_file: /home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/COCO/annotations/val.json
#############################################
transforms:
type: Compose
ops: ~
shuffle: False
batch_size: 8
num_workers: 4
drop_last: False
代码文件地址:RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/train.py
config参数需根据自己的rtdetr_r18vd_6x_coco.yml文件所在地址设置。
"""by lyuwenyu
"""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
import argparse
import src.misc.dist as dist
from src.core import YAMLConfig
from src.solver import TASKS
def main(args, ) -> None:
'''main
'''
dist.init_distributed()
assert not all([args.tuning, args.resume]), \
'Only support from_scrach or resume or tuning at one time'
cfg = YAMLConfig(
args.config,
resume=args.resume,
use_amp=args.amp,
tuning=args.tuning
)
solver = TASKS[cfg.yaml_cfg['task']](cfg)
if args.test_only:
solver.val()
else:
solver.fit()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c', type=str,default = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/configs/rtdetr/rtdetr_r18vd_6x_coco.yml" )
parser.add_argument('--resume', '-r', type=str,default =False )
parser.add_argument('--tuning', '-t', type=str,default = False )
parser.add_argument('--test-only', action='store_true', default=False,)
parser.add_argument('--amp', action='store_true', default=False,)
args = parser.parse_args()
main(args)
初次训练时会自动下载权重文件,如因报错无法自动下载,可手动下载并将权重文件放入RT-DETR/RT-DETR-main/rtdetr_pytorch/ 文件夹即可。
你可以于本人分享的网盘链接中下载权重文件。
直接运行train.py或在命令行中输入
python train.py
即可开始执行训练。
训练日志如下图所示。
训练完毕后可进行模型导出和推理,参考:http://t.csdnimg.cn/G59hR