linux配置mmdetection2.8训练自定义coco数据集(一)

文章目录

  • 前言
  • 1、安装环境
  • 2、mmcv和mmdet安装步骤
  • 3、voc转coco脚本
    • 3.1 文件夹准备
    • 3.2 运行转换脚本
    • 3.3 整理文件夹
  • 4. 运行faster-rcnn
    • 4.1 配置步骤
    • 4.2 运行命令
    • 4.3 运行截图
  • 总结
  • 参考资料


前言

网上的安装教程很多,这里推荐使用源码安装,因为mmcv的源码中提供了test的demo,方便验证安装是否成功。


1、安装环境

torch == 1.7.0
torchvision == 0.8.0
mmcv == 1.2.6
mmdetection == 2.8.0
CUDA == 10.1

2、mmcv和mmdet安装步骤

# install mmcv (full version with compiled ops), pinned to the version listed above
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
git checkout v1.2.6
MMCV_WITH_OPS=1 pip install -e .  # install the full version (builds the CUDA/C++ ops)
cd ..  # go back up so mmdetection is not cloned inside the mmcv checkout
# install mmdetection, pinned to the version listed above
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
git checkout v2.8.0
pip install -r requirements/build.txt
pip install -v -e .

3、voc转coco脚本

3.1 文件夹准备

 由于mmdetection中大多数模型的配置文件默认使用coco格式的数据集,而且coco格式的应用也更加广泛,因此这里我们首先制作一个coco格式的数据集。这里采用的是将VOC2007转换成coco格式,VOC2007的数据量也比较小。转换脚本见下文3.2节。
 说下数据集摆放格式,即需要在VOC2007文件夹同级下新建三个空文件夹(train2017,val2017和annotations)。
linux配置mmdetection2.8训练自定义coco数据集(一)_第1张图片

3.2 运行转换脚本

 运行下面的代码。【注意:需要把 parse_args() 中 --devkit_path 的默认值改成自己的VOCdevkit路径,或者在运行时通过 --devkit_path 参数指定!!!】

import os
import os.path as osp
import json
import argparse
import xml.etree.ElementTree as ET


# Id assigned to the first annotation; each further annotation gets the next integer.
START_BOUNDING_BOX_ID = 1
# Category name -> category id. Starts empty and is filled in by _convert() as new
# class names are encountered. Because it is module-level, it is shared by every
# split converted in the same run, which keeps ids consistent across the JSON files.
PRE_DEFINE_CATEGORIES = dict()


def get(root, name):
    """Return all sub-elements of ``root`` matching ``name``.

    Args:
        root: XML element to search.
        name: Tag name / ElementPath expression passed to ``findall``.

    Returns:
        list: The matching elements; empty when nothing matches.
    """
    return root.findall(name)


def get_and_check(root, name, length):
    """Find sub-elements of ``root`` matching ``name`` and validate their count.

    Args:
        root: XML element to search.
        name: Tag name / ElementPath expression passed to ``findall``.
        length: Expected number of matches. A value <= 0 disables the exact
            count check, although at least one match is still required.

    Returns:
        The single matching element when ``length == 1``; otherwise the list
        of all matches.

    Raises:
        NotImplementedError: If nothing matches, or if ``length > 0`` and the
            number of matches differs from ``length``.
    """
    found = root.findall(name)
    if not found:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(found) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(found)))
    return found[0] if length == 1 else found


def get_filename_as_int(filename):
    """Return the integer image id encoded in an image file name.

    The file-name stem (extension removed) is parsed as an int and used
    directly as the COCO image id, e.g. ``"000005.jpg"`` -> ``5``.

    Args:
        filename: Image file name, with or without an extension.

    Returns:
        int: The stem of ``filename`` parsed as an integer.

    Raises:
        NotImplementedError: If the stem cannot be parsed as an integer. This
            exception type is kept for backward compatibility; the underlying
            error is attached as ``__cause__``.
    """
    try:
        filename = os.path.splitext(filename)[0]
        return int(filename)
    except (TypeError, ValueError) as err:
        # Only parse-related failures are translated; anything else (e.g.
        # KeyboardInterrupt) now propagates instead of being swallowed.
        raise NotImplementedError('Filename %s is supposed to be an integer.' % (filename,)) from err


def _convert(xml_list, xml_dir, json_file):
    """Convert the VOC XML annotations listed in split file(s) into one COCO JSON.

    Args:
        xml_list: Path to a split file (one image id per line), or a list of
            such paths when several VOC years are merged.
        xml_dir: Annotation folder(s) matching ``xml_list``. This is a single
            path when ``xml_list`` is a path. When ``xml_list`` is a list, it
            is a list of paths aligned index-by-index with ``xml_list``.
        json_file: Output path of the COCO-style JSON file.

    Side effects:
        Adds every newly seen category name to the module-level
        ``PRE_DEFINE_CATEGORIES`` dict, so successive calls in one run share
        the same category ids. Prints one progress line per image.

    Raises:
        NotImplementedError: On malformed XML content (propagated from the
            helper functions) or when an XML file has more than one <path>.
        AssertionError: If a box has non-positive width or height.
    """
    if isinstance(xml_list, list):
        list_paths = xml_list
    else:
        list_paths = [xml_list]
        xml_dir = [xml_dir]

    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    # Deliberately the shared module-level dict (see docstring).
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for i, list_path in enumerate(list_paths):
        # Read the split file up front so its handle is closed even if a later
        # XML parse fails.
        with open(list_path, 'r') as list_fp:
            names = [line.strip() for line in list_fp]
        for name in names:
            if not name:
                # Blank lines (e.g. a trailing newline) do not name an image;
                # previously they produced a bogus '<dir>/.xml' path.
                continue
            print("Processing %s" % (name + '.xml'))
            xml_f = os.path.join(xml_dir[i], name + '.xml')
            root = ET.parse(xml_f).getroot()
            path = get(root, 'path')
            if len(path) == 1:
                filename = os.path.basename(path[0].text)
            elif len(path) == 0:
                filename = get_and_check(root, 'filename', 1).text
            else:
                raise NotImplementedError('%d paths found in %s' % (len(path), name))

            image_id = get_filename_as_int(filename)
            size = get_and_check(root, 'size', 1)
            width = int(get_and_check(size, 'width', 1).text)
            height = int(get_and_check(size, 'height', 1).text)
            # file_name has no directory prefix: images are copied flat into
            # train2017/ and val2017/.
            json_dict['images'].append(
                {'file_name': filename, 'height': height, 'width': width, 'id': image_id})
            for obj in get(root, 'object'):
                category = get_and_check(obj, 'name', 1).text
                if category not in categories:
                    # Ids are assigned in order of first appearance, starting at 0.
                    categories[category] = len(categories)
                category_id = categories[category]
                bndbox = get_and_check(obj, 'bndbox', 1)
                # Only xmin/ymin are shifted by -1, so the stored width is
                # xmax - xmin_orig + 1. Presumably this treats VOC corners as
                # 1-based inclusive pixel indices.
                xmin = int(get_and_check(bndbox, 'xmin', 1).text) - 1
                ymin = int(get_and_check(bndbox, 'ymin', 1).text) - 1
                xmax = int(get_and_check(bndbox, 'xmax', 1).text)
                ymax = int(get_and_check(bndbox, 'ymax', 1).text)
                assert (xmax > xmin)
                assert (ymax > ymin)
                o_width = abs(xmax - xmin)
                o_height = abs(ymax - ymin)
                ann = {
                    'area': o_width * o_height, 'iscrowd': 0, 'image_id': image_id,
                    'bbox': [xmin, ymin, o_width, o_height],
                    'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                    'segmentation': []}
                json_dict['annotations'].append(ann)
                bnd_id = bnd_id + 1

    for cate, cid in categories.items():
        json_dict['categories'].append({'supercategory': 'none', 'id': cid, 'name': cate})
    # Context manager guarantees the output file is flushed and closed.
    with open(json_file, 'w') as json_fp:
        json.dump(json_dict, json_fp)


def parse_args():
    """Parse the command-line options of the VOC -> COCO converter.

    Returns:
        argparse.Namespace with two fields. ``devkit_path`` is the VOC root
        that holds the VOC2007 / VOC2012 sub-folders. ``out_dir`` is the
        folder under which ``annotations/`` is written; it is ``None`` when
        not given.
    """
    parser = argparse.ArgumentParser(description='Convert PASCAL VOC annotations to coco format')
    # VOC root folder containing VOC2007 and/or VOC2012. Change this default to
    # your own path, or pass --devkit_path on the command line.
    parser.add_argument('--devkit_path', default='/home/wujian/VOCdevkit/', help='pascal voc devkit path')
    # Folder in which annotations/ is created (main() falls back to devkit_path).
    parser.add_argument('-o', '--out-dir', help='output path')
    return parser.parse_args()


def main():
    """Convert the VOC splits found under ``--devkit_path`` to COCO JSON files.

    Detects which of ``VOC2007`` / ``VOC2012`` exist under the devkit root and
    selects the matching splits. It then writes one
    ``<out_dir>/annotations/<prefix>_<split>.json`` file per split, e.g.
    ``voc07_trainval.json`` and ``voc07_test.json`` when only VOC2007 exists.
    The ``annotations`` folder is created if it does not exist yet.

    Raises:
        NotImplementedError: If neither ``VOC2007`` nor ``VOC2012`` is found.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    out_dir = args.out_dir if args.out_dir else devkit_path

    years = [y for y in ('2007', '2012') if osp.isdir(osp.join(devkit_path, 'VOC' + y))]
    if years == ['2007']:
        year, prefix = '2007', 'voc07'
        split = ['trainval', 'test']  # (train split, test split)
    elif years == ['2012']:
        year, prefix = '2012', 'voc12'
        split = ['train', 'val']  # (train split, test split)
    elif years == ['2007', '2012']:
        # Merge both years: each entry pairs the 2007 split with the 2012 one.
        year, prefix = years, 'voc0712'
        split = [['trainval', 'train'], ['test', 'val']]  # (train splits, test splits)
    else:
        raise NotImplementedError

    annotations_path = osp.join(out_dir, 'annotations')
    # _convert opens its output file directly, so make sure the folder exists
    # instead of failing with FileNotFoundError.
    os.makedirs(annotations_path, exist_ok=True)

    for split_ in split:
        if isinstance(split_, list):
            dataset_name = prefix + '_' + split_[0]
        else:
            dataset_name = prefix + '_' + split_
        print('processing {} ...'.format(dataset_name))
        out_file = osp.join(annotations_path, dataset_name + '.json')
        if isinstance(split_, list):
            filelists = []
            xml_dirs = []
            for y, s in zip(year, split_):
                filelists.append(osp.join(devkit_path, 'VOC{}/ImageSets/Main/{}.txt'.format(y, s)))
                xml_dirs.append(osp.join(devkit_path, 'VOC{}/Annotations'.format(y)))
        else:
            filelists = osp.join(devkit_path, 'VOC{}/ImageSets/Main/{}.txt'.format(year, split_))
            xml_dirs = osp.join(devkit_path, 'VOC{}/Annotations'.format(year))
        _convert(filelists, xml_dirs, out_file)

    print('Done!')


if __name__ == '__main__':
    main()

3.3 整理文件夹

 运行完后,annotations文件夹内就生成了coco格式的标注文件。然后将两个json文件分别重命名:voc07_trainval.json 改为 instances_train2017.json,voc07_test.json 改为 instances_val2017.json。
 然后根据trainval.txt和test.txt中的图像索引,将VOC中的JPG图像分别复制到train2017和val2017文件夹中。下面是一个简单的脚本(需要分别对trainval.txt和test.txt各运行一次):

import os
import shutil

# Copy the JPEGs listed in one VOC split file into a flat COCO-style image folder.
# Run once per split: test.txt -> val2017/, then again with trainval.txt -> train2017/.
root_path = '/home/wujian/VOCdevkit/VOC2007/'         # VOC2007 folder
trainval_path = root_path + "ImageSets/Main/test.txt" # change to trainval.txt / test.txt
jpg_path = root_path + 'JPEGImages/'
save_path = "/home/wujian/VOCdevkit/val2017/"         # change to train2017 / val2017

# Create the destination folder if it does not exist yet.
os.makedirs(save_path, exist_ok=True)

with open(trainval_path, 'r') as f:
    for ele in f:
        cur_jpgname = ele.strip()  # image id on this line
        if not cur_jpgname:
            continue  # skip blank lines such as a trailing newline
        # Byte-for-byte copy. Decoding and re-encoding with cv2 would
        # re-compress the JPEG (changing pixel values) and is much slower.
        # A missing source image now raises FileNotFoundError instead of
        # an obscure cv2 error on a None image.
        shutil.copyfile(jpg_path + cur_jpgname + '.jpg', save_path + cur_jpgname + '.jpg')
    print('Done!')

  okay,然后整理成如下格式:
linux配置mmdetection2.8训练自定义coco数据集(一)_第2张图片

4. 运行faster-rcnn

4.1 配置步骤

  (1)新建一个data文件夹,里面存放转换好的coco数据集。
linux配置mmdetection2.8训练自定义coco数据集(一)_第3张图片
 (2)修改mmdet/datasets/coco.py,将CLASSES修改如下:

@DATASETS.register_module()
class CocoDataset(CustomDataset):
    """COCO-format dataset whose CLASSES are replaced by the 20 PASCAL VOC classes.

    Only the CLASSES attribute is edited here; the rest of the class (not
    shown) is assumed to be mmdetection's stock implementation. The names
    presumably need to match the category ``name`` values written by the
    VOC -> COCO conversion script (assumption: categories are looked up by
    name; verify against mmdetection's CocoDataset).
    """
    # Original 80 COCO classes, kept for reference:
    # CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #            'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
    # The 20 PASCAL VOC classes used for this custom dataset.
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
    'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train','tvmonitor')

 (3)在mmdet/core/evaluation/class_names.py中,将coco_classes()修改如下:

# Original 80 COCO class names, kept for reference:
# def coco_classes():
#     return [
#         'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
#         'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
#         'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
#         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
#         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
#         'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
#         'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
#         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
#         'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
#         'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
#         'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
#         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
#         'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
#     ]
def coco_classes():
    """Return the 20 PASCAL VOC class names used in place of the COCO classes.

    The list has the same names, in the same order, as the CLASSES tuple set
    in mmdet/datasets/coco.py; keep the two in sync. It is presumably consumed
    by mmdetection's evaluation/visualization helpers (verify).

    Returns:
        list[str]: A new list of 20 class names on every call.
    """
    return [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

 (4) 在configs/_base_/models/faster_rcnn_r50_fpn.py中,将roi_head中bbox_head的num_classes从原来的80改为20。
 (5) 在configs/_base_/datasets/coco_detection.py中将data_root改成data文件夹下coco数据集的绝对路径

4.2 运行命令

  之后,通过以下命令开始训练:

python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py

4.3 运行截图

linux配置mmdetection2.8训练自定义coco数据集(一)_第4张图片

总结

 这是mmdetection系列第一篇,后续会开mmdetection源码解读等等。

参考资料

 mmdetection源码解读

你可能感兴趣的:(mmdetection,python,mmdetection,mmcv,pytorch,目标检测)