#OpenMMLab-mmdetection代码注释学习(1)-dataset_converters【自用】

mmdetection对COCO数据集的支持较好,并且提供了COCO API。但个人在学习时使用标注工具LabelImg进行标注,得到的是与图像一一对应的xml标注文件;而COCO数据集使用的是将所有标注信息集合为一个整体的json标注文件。mmdetection的tools当中提供了将VOC数据集格式转换为COCO数据集格式的工具——dataset_converters,可以将自己的类VOC数据集转换为COCO数据集格式后再进行训练。

整体代码结构

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np

from mmdet.core import voc_classes

# Map every class name yielded by voc_classes() to its integer position.
label_ids = {name: i for i, name in enumerate(voc_classes())}


def _to_ann_arrays(bboxes, labels):
    """Turn parallel box/label lists into float32 (k, 4) and int64 (k,) arrays."""
    if not bboxes:
        return (np.zeros((0, 4)).astype(np.float32),
                np.zeros((0, )).astype(np.int64))
    # Coordinates are shifted by -1 (presumably converting the 1-based VOC
    # pixel convention to 0-based -- confirm against the dataset spec).
    shifted = np.array(bboxes, ndmin=2) - 1
    return shifted.astype(np.float32), np.array(labels).astype(np.int64)


def parse_xml(args):
    """Parse one VOC-style XML annotation file into an annotation dict.

    Args:
        args (tuple[str, str]): ``(xml_path, img_path)``. ``xml_path`` is the
            annotation file to read; the base name of ``img_path`` is stored
            as ``filename`` in the result.

    Returns:
        dict: ``filename``, ``width``, ``height`` and ``ann``. ``ann`` holds
        ``bboxes``/``labels`` for regular objects and
        ``bboxes_ignore``/``labels_ignore`` for objects whose ``difficult``
        flag is non-zero. Boxes are float32 arrays of shape (k, 4) ordered
        ``[xmin, ymin, xmax, ymax]``; labels are int64 arrays of shape (k,).

    Raises:
        KeyError: If an object's class name is not a key of ``label_ids``.
    """
    xml_path, img_path = args
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    bboxes, labels = [], []
    bboxes_ignore, labels_ignore = [], []
    for obj in root.findall('object'):
        label = label_ids[obj.find('name').text]
        # Not every annotation tool writes <difficult>; a missing tag is
        # treated as "not difficult" instead of crashing on None.text.
        difficult_node = obj.find('difficult')
        difficult = 0 if difficult_node is None else int(difficult_node.text)
        bnd_box = obj.find('bndbox')
        bbox = [
            int(bnd_box.find(tag).text)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]
        if difficult:
            bboxes_ignore.append(bbox)
            labels_ignore.append(label)
        else:
            bboxes.append(bbox)
            labels.append(label)

    bboxes, labels = _to_ann_arrays(bboxes, labels)
    bboxes_ignore, labels_ignore = _to_ann_arrays(bboxes_ignore,
                                                  labels_ignore)
    return {
        # Use the base name rather than a fixed path depth, so the result
        # does not depend on how many directories precede the file name.
        'filename': osp.basename(img_path),
        'width': w,
        'height': h,
        'ann': {
            'bboxes': bboxes,
            'labels': labels,
            'bboxes_ignore': bboxes_ignore,
            'labels_ignore': labels_ignore
        }
    }


def cvt_annotations(devkit_path, years, split, out_file):
    """Convert one split of one or more VOC years and dump the result.

    Args:
        devkit_path (str): Directory containing the ``VOC{year}`` folders.
        years (str | list[str]): A single year or a list of years.
        split (str): Name of the image-set list under ``ImageSets/Main``
            (without the ``.txt`` suffix).
        out_file (str): Output path. When it ends with ``json`` the
            annotations are converted with ``cvt_to_coco_json`` first.

    Returns:
        The object that was dumped, or ``None`` (with nothing written) as
        soon as an image-set list file is missing.
    """
    year_list = years if isinstance(years, list) else [years]

    collected = []
    for year in year_list:
        # Text file listing the image names that belong to this split.
        filelist = osp.join(devkit_path,
                            f'VOC{year}/ImageSets/Main/{split}.txt')
        if not osp.isfile(filelist):
            print(f'filelist does not exist: {filelist}, '
                  f'skip voc{year} {split}')
            return None

        # Pair every annotation file with the relative path of its image.
        tasks = []
        for img_name in mmcv.list_from_file(filelist):
            xml_path = osp.join(devkit_path,
                                f'VOC{year}/Annotations/{img_name}.xml')
            img_path = f'VOC{year}/JPEGImages/{img_name}.jpg'
            tasks.append((xml_path, img_path))

        # parse_xml is applied to each pair while a progress bar is shown.
        collected.extend(mmcv.track_progress(parse_xml, tasks))

    annotations = collected
    if out_file.endswith('json'):
        annotations = cvt_to_coco_json(annotations)
    mmcv.dump(annotations, out_file)
    return annotations


def cvt_to_coco_json(annotations):
    """Re-pack per-image annotation dicts into one COCO-style dict.

    Args:
        annotations (list[dict]): Items with ``filename``, ``width``,
            ``height`` and ``ann``; ``ann`` carries the arrays ``bboxes``,
            ``labels``, ``bboxes_ignore`` and ``labels_ignore``.

    Returns:
        dict: A dict with the keys ``images``, ``type``, ``categories`` and
        ``annotations``. Image ids and annotation ids are consecutive
        integers starting at 0.
    """
    coco = {
        'images': [],
        'type': 'instance',
        'categories': [],
        'annotations': [],
    }

    def make_ann_item(ann_id, img_id, cat_id, box, difficult_flag):
        """Build one annotation entry from an ``x1, y1, x2, y2`` box."""
        x1, y1 = int(box[0]), int(box[1])
        x2, y2 = int(box[2]), int(box[3])
        # Rectangle outline: left-top, left-bottom, right-bottom, right-top.
        polygon = [x1, y1, x1, y2, x2, y2, x2, y1]
        xywh = np.array([box[0], box[1], box[2] - box[0], box[3] - box[1]])
        return {
            'segmentation': [polygon],
            'area': int(xywh[2] * xywh[3]),
            'ignore': 0,
            # Boxes coming from the "ignore" lists are marked as crowd.
            'iscrowd': 1 if difficult_flag == 1 else 0,
            'image_id': int(img_id),
            'bbox': xywh.astype(int).tolist(),
            'category_id': int(cat_id),
            'id': int(ann_id),
        }

    coco['categories'].extend({
        'supercategory': str('none'),
        'id': int(category_id),
        'name': str(name),
    } for category_id, name in enumerate(voc_classes()))

    # (boxes key, labels key, difficult flag), processed in this order.
    box_groups = (
        ('bboxes', 'labels', 0),
        ('bboxes_ignore', 'labels_ignore', 1),
    )

    seen_files = set()
    next_ann_id = 0
    for image_id, ann_dict in enumerate(annotations):
        file_name = ann_dict['filename']
        ann = ann_dict['ann']
        assert file_name not in seen_files
        coco['images'].append({
            'id': int(image_id),
            'file_name': str(file_name),
            'height': int(ann_dict['height']),
            'width': int(ann_dict['width']),
        })
        seen_files.add(file_name)

        for box_key, label_key, flag in box_groups:
            boxes = ann[box_key][:, :4]
            group_labels = ann[label_key]
            for idx, box in enumerate(boxes):
                coco['annotations'].append(
                    make_ann_item(next_ann_id, image_id, group_labels[idx],
                                  box, flag))
                next_ann_id += 1

    return coco


def parse_args():
    """Parse the command line of the converter script.

    Returns:
        argparse.Namespace: ``devkit_path`` (positional), ``out_dir``
        (``-o``/``--out-dir``, ``None`` when omitted) and ``out_format``
        (``'pkl'`` or ``'coco'``, default ``'coco'``).
    """
    # Each entry is (flags, keyword options) for one add_argument call.
    cli_spec = (
        (('devkit_path', ), dict(help='pascal voc devkit path')),
        (('-o', '--out-dir'), dict(help='output path')),
        (('--out-format', ),
         dict(
             default='coco',
             choices=('pkl', 'coco'),
             help='output format, "coco" indicates coco annotation format')),
    )
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmdetection format')
    for flags, options in cli_spec:
        parser.add_argument(*flags, **options)
    return parser.parse_args()


def main():
    """Convert every VOC year found under the devkit path.

    The output directory defaults to the devkit path and is created first.
    For each of ``VOC2007`` / ``VOC2012`` that exists, the ``train``,
    ``val``, ``trainval`` and ``test`` splits are converted; when both
    exist, a combined ``voc0712`` set is also produced for the first three
    splits.

    Raises:
        IOError: If neither ``VOC2007`` nor ``VOC2012`` is a subfolder of
            the devkit path.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    # Fall back to the dataset directory when no output directory is given.
    out_dir = args.out_dir or devkit_path
    mmcv.mkdir_or_exist(out_dir)

    has_2007 = osp.isdir(osp.join(devkit_path, 'VOC2007'))
    has_2012 = osp.isdir(osp.join(devkit_path, 'VOC2012'))

    # (year argument, output file prefix) for every conversion to run.
    jobs = []
    if has_2007:
        jobs.append(('2007', 'voc07'))
    if has_2012:
        jobs.append(('2012', 'voc12'))
    if has_2007 and has_2012:
        jobs.append((['2007', '2012'], 'voc0712'))
    if not jobs:
        raise IOError(f'The devkit path {devkit_path} contains neither '
                      '"VOC2007" nor "VOC2012" subfolder')

    # The "coco" format is stored as a .json file; others use their name.
    out_fmt = '.json' if args.out_format == 'coco' else f'.{args.out_format}'

    for year, prefix in jobs:
        splits = ['train', 'val', 'trainval']
        # The test split is only converted for single-year runs.
        if not isinstance(year, list):
            splits.append('test')
        for split in splits:
            dataset_name = prefix + '_' + split
            print(f'processing {dataset_name} ...')
            cvt_annotations(devkit_path, year, split,
                            osp.join(out_dir, dataset_name + out_fmt))
    print('Done!')


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()

mmcv库函数

验证创建文件夹

def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` (and missing parents) unless it already exists.

    Args:
        dir_name (str): Directory path; ``~`` is expanded. An empty string
            is a no-op.
        mode (int): Permission bits passed to ``os.makedirs``.
    """
    if dir_name != '':
        os.makedirs(osp.expanduser(dir_name), mode=mode, exist_ok=True)

读取text文档并输出至list

def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
    """Read a text file into a list with one string per line.

    Args:
        filename (str): Path of the file to read.
        prefix (str): Text prepended to every returned item.
        offset (int): Number of leading lines to skip.
        max_num (int): Maximum number of lines to return; zero or a
            negative value means no limit.
        encoding (str): Encoding used to open the file. Default utf-8.

    Returns:
        list[str]: The lines with trailing ``\\n``/``\\r`` characters removed.
    """
    lines = []
    with open(filename, 'r', encoding=encoding) as stream:
        # Discard the first ``offset`` lines.
        for _ in range(offset):
            stream.readline()
        for raw_line in stream:
            if max_num > 0 and len(lines) >= max_num:
                break
            lines.append(prefix + raw_line.rstrip('\n\r'))
    return lines

创建进度条并将迭代内容输入至指定函数当中得到结果

def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
    """Run ``func`` over ``tasks`` sequentially while showing a progress bar.

    Args:
        func (callable): Called as ``func(task, **kwargs)`` for each task.
        tasks (list or tuple[Iterable, int]): Either a sized iterable of
            tasks, or a ``(tasks, total num)`` pair.
        bar_width (int): Width of progress bar.
        file: Stream handed to ``ProgressBar``. Default ``sys.stdout``.

    Returns:
        list: The result of ``func`` for every task, in task order.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor an iterable.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        tasks, task_num = tasks
    elif isinstance(tasks, Iterable):
        # A plain list of tasks takes this branch; its length is the total.
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')

    progress = ProgressBar(task_num, bar_width, file=file)
    outputs = []
    for item in tasks:
        # Feed each task to the target function and keep its result.
        outputs.append(func(item, **kwargs))
        progress.update()
    progress.file.write('\n')
    return outputs

dump函数:根据最终得到的annotations生成所需的json文件

def dump(obj, file=None, file_format=None, **kwargs):
    """Serialize ``obj`` to a string, a path or a file-like object.

    The handler is looked up in ``file_handlers`` by ``file_format``; when
    no format is given it is taken from the text after the last ``.`` of
    the file name.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): Target of
            the dump. When omitted the handler's string output is returned.
        file_format (str, optional): Key into ``file_handlers``.
        **kwargs: Forwarded to the selected handler.

    Returns:
        The handler's string output when ``file`` is None, otherwise None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format has no handler, or ``file`` is neither a
            string path nor an object with a ``write`` attribute.
    """
    if isinstance(file, Path):
        file = str(file)

    if file_format is None:
        if is_str(file):
            # Use the file name suffix as the format.
            file_format = file.rsplit('.', 1)[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')

    # According to the original note, file_handlers maps 'json', 'yaml',
    # 'yml', 'pickle' and 'pkl' to their handler instances.
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')
    handler = file_handlers[file_format]

    if file is None:
        return handler.dump_to_str(obj, **kwargs)
    if is_str(file):
        handler.dump_to_path(obj, file, **kwargs)
        return None
    if hasattr(file, 'write'):
        handler.dump_to_fileobj(obj, file, **kwargs)
        return None
    raise TypeError('"file" must be a filename str or a file-object')

你可能感兴趣的:(计算机视觉,深度学习,attention)