快速上手mmdetection训练自己的数据集

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、MMdetection目录结构
  • 二、使用步骤
    • 1.准备数据集
    • 2.训练自己的数据集
    • 3.训练
  • 总结


前言

mmdetection学习记录


一、MMdetection目录结构

mmdetection-master
--configs
--demo
--docker
--docs
--mmdet
--requirements
--tests
--tools

configs包含了各种网络的配置文件,mmdetection2.0采用了继承机制,大多数模型都是继承_base_中的models,只是更换主干网络。其中_base_中还包括datasets,schedules可以定义数据集和训练策略。
mmdet包含整个项目最重要的部分,apis存放训练,验证,推理部分的函数,core存放了Anchor和bbox相关以及后处理等操作。datasets存放了各个数据集的处理代码。models存放检测模型的backbone,neck,head,以及损失函数定义。ops存放DCN,NMS,ROIPOOLING等优化算法。utils则是各种工具函数。

二、使用步骤

1.准备数据集

首先在主目录下创建data文件夹,mmdetection支持coco,voc,cityscapes,以及自定义类型数据集,数据集组织格式如下:
├── data
│ ├── coco
│ │ ├── annotations
│ │ ├── train2017
│ │ ├── val2017
│ │ ├── test2017
│ ├── cityscapes
│ │ ├── annotations
│ │ ├── leftImg8bit
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── gtFine
│ │ │ ├── train
│ │ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2007
│ │ ├── VOC2012

xmlfilepath=r'./VOCdevkit/VOC2007/Annotations'
saveBasePath=r"./VOCdevkit/VOC2007/ImageSets/Main/"
 
trainval_percent=0.8
train_percent=0.8

temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
    if xml.endswith(".xml"):
        total_xml.append(xml)

num=len(total_xml)  
list=range(num)  
tv=int(num*trainval_percent)  
tr=int(tv*train_percent)  
trainval= random.sample(list,tv)  
train=random.sample(trainval,tr)  
 
print("train and val size",tv)
print("traub suze",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')  
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')  
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')  
 
for i  in list:  
    name=total_xml[i][:-4]+'\n'  
    if i in trainval:  
        ftrainval.write(name)  
        if i in train:  
            ftrain.write(name)  
        else:  
            fval.write(name)  
    else:  
        ftest.write(name)  
  
ftrainval.close()  
ftrain.close()  
fval.close()  
ftest .close()

用来分割train,val,test.txt的代码

2.训练自己的数据集

这里做示例的是voc格式的数据集,需要修改四个地方。
1./configs/base/datasets/voc0712.py

data = dict(
    samples_per_gpu=2,#batchsize设置
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type=dataset_type,
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                #data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=[data_root + 'VOC2007/', 
            #data_root + 'VOC2012/'],
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='mAP')

因为voc2007与voc2012是互斥的,我使用的是voc2007格式,这里需要注释掉voc2012的相关内容,同理,如果使用voc2012也需要注释掉voc2007.
2./mmdet/datasets/voc.py

class VOCDataset(XMLDataset):

   ''' CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')'''
      CLASSES = ('aeroplane', )

注意如果是单类的话后面也需要加逗号。
3./mmdet/core/evaluation/class_names.py

def voc_classes():
    '''return [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]'''
	 return [
        'aeroplane'
    ]

这里修改验证时的类别数,如果不修改训练时不会出错,但是验证时会报错。
4.修改configs中模型的类别数,以faster_rcnn为例。
/configs/base/models/faster_rcnn_r50_fpn.py
搜索num_classes,默认是80,这是coco的类别数,改成自己的类别数,不需要背景+1。例如cascade_rcnn这样的有三个num_classes都需要修改。

3.训练

这里训练还是以faster_rcnn_r50_fpn_1x_coco.py为例

python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py

测试

python tools/test.py configs/faster_rcnn_r50_fpn_1x.py work_dirs/epoch_100.pth --out ./result/result_100.pkl --eval bbox --show
 

总结

这里总结了mmdetection如何训练自己的数据集,后面将会更新具体的调参细节。

你可能感兴趣的:(深度学习)