mmdetection has particularly good support for the COCO dataset and ships with the COCO API. When labeling my own data with LabelImg, however, I get one XML annotation file per image (VOC style), whereas COCO collects all annotation information into a single JSON file. The tools directory of mmdetection provides a dataset-conversion script, pascal_voc.py, which converts a VOC-style (or VOC-like) dataset into COCO format so that a custom dataset can be trained with the COCO pipeline.
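Before walking through the source, here is how the script is typically invoked. The path below is an assumption: recent mmdetection releases place the converter under tools/dataset_converters/pascal_voc.py, while older ones use tools/convert_datasets/pascal_voc.py, so adjust it to your installation.

python tools/dataset_converters/pascal_voc.py /path/to/VOCdevkit -o /path/to/output --out-format coco

The full script is reproduced below with inline commentary.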
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from mmdet.core import voc_classes
label_ids = {name: i for i, name in enumerate(voc_classes())}
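# Illustrative content of label_ids, assuming the standard 20-class VOC label
# set returned by voc_classes():
# label_ids = {'aeroplane': 0, 'bicycle': 1, 'bird': 2, ..., 'tvmonitor': 19}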
def parse_xml(args):
xml_path, img_path = args
    tree = ET.parse(xml_path)  # parse the xml annotation file
root = tree.getroot()
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text)
bboxes = []
labels = []
bboxes_ignore = []
labels_ignore = []
for obj in root.findall('object'):
name = obj.find('name').text
label = label_ids[name]
difficult = int(obj.find('difficult').text)
bnd_box = obj.find('bndbox')
bbox = [
int(bnd_box.find('xmin').text),
int(bnd_box.find('ymin').text),
int(bnd_box.find('xmax').text),
int(bnd_box.find('ymax').text)
]
        # collect the class label and bounding box of every annotated object
if difficult:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
else:
bboxes = np.array(bboxes, ndmin=2) - 1
labels = np.array(labels)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
labels_ignore = np.array(labels_ignore)
annotation = {
'filename': img_path.split('/')[2],
'width': w,
'height': h,
'ann': {
'bboxes': bboxes.astype(np.float32),
'labels': labels.astype(np.int64),
'bboxes_ignore': bboxes_ignore.astype(np.float32),
            'labels_ignore': labels_ignore.astype(np.int64)
            # the information parsed from the xml file is packed into this
            # middle-format annotation dict, which is returned to
            # cvt_annotations once all objects have been processed
}
}
return annotation
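# Illustrative (hypothetical) return value for a single image, showing the
# middle format produced by parse_xml:
# {'filename': '001552.jpg', 'width': 500, 'height': 375,
#  'ann': {'bboxes': array([[ 47., 238., 193., 370.]], dtype=float32),
#          'labels': array([11]),
#          'bboxes_ignore': array([], shape=(0, 4), dtype=float32),
#          'labels_ignore': array([], dtype=int64)}}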
def cvt_annotations(devkit_path, years, split, out_file):
    # arguments: VOC devkit path, dataset year(s), split name, output annotation file (json)
if not isinstance(years, list):
        years = [years]  # wrap a single year in a list so it can be iterated below
annotations = []
for year in years:
        filelist = osp.join(devkit_path,
                            f'VOC{year}/ImageSets/Main/{split}.txt')
        # the .txt file that lists the image ids belonging to this split
if not osp.isfile(filelist):
print(f'filelist does not exist: {filelist}, '
f'skip voc{year} {split}')
return
        img_names = mmcv.list_from_file(filelist)
        # read the image ids from the txt file into a list, e.g.
        # img_names = ['001552', '000350', '002229', '002597', '001165', '002648', ...]
        xml_paths = [
            osp.join(devkit_path, f'VOC{year}/Annotations/{img_name}.xml')
            for img_name in img_names
        ]
        # list of absolute xml annotation paths, e.g.
        # xml_paths = ['/home/lym/darknet/VOC2012/Annotations/001552.xml',
        #              '/home/lym/darknet/VOC2012/Annotations/000350.xml', ...]
        img_paths = [
            f'VOC{year}/JPEGImages/{img_name}.jpg' for img_name in img_names
        ]
        # list of relative jpg image paths, e.g.
        # img_paths = ['VOC2012/JPEGImages/001552.jpg', 'VOC2012/JPEGImages/000350.jpg', ...]
        part_annotations = mmcv.track_progress(parse_xml,
                                               list(zip(xml_paths, img_paths)))
        # track_progress applies parse_xml to every (xml_path, img_path) pair
        # and shows a progress bar while doing so
annotations.extend(part_annotations)
    if out_file.endswith('json'):
        annotations = cvt_to_coco_json(annotations)  # convert to a COCO-style dict
    mmcv.dump(annotations, out_file)  # write the annotations to the output file
return annotations
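# Illustrative (hypothetical) direct call for a single year and split:
# cvt_annotations('/home/lym/darknet', '2012', 'train', 'voc12_train.json')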
def cvt_to_coco_json(annotations):  # convert the middle-format list into a COCO-style dict
image_id = 0
annotation_id = 0
coco = dict()
coco['images'] = []
coco['type'] = 'instance'
coco['categories'] = []
coco['annotations'] = []
image_set = set()
def addAnnItem(annotation_id, image_id, category_id, bbox, difficult_flag):
annotation_item = dict()
annotation_item['segmentation'] = []
seg = []
# bbox[] is x1,y1,x2,y2
# left_top
seg.append(int(bbox[0]))
seg.append(int(bbox[1]))
# left_bottom
seg.append(int(bbox[0]))
seg.append(int(bbox[3]))
# right_bottom
seg.append(int(bbox[2]))
seg.append(int(bbox[3]))
# right_top
seg.append(int(bbox[2]))
seg.append(int(bbox[1]))
annotation_item['segmentation'].append(seg)
xywh = np.array(
[bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]])
annotation_item['area'] = int(xywh[2] * xywh[3])
if difficult_flag == 1:
annotation_item['ignore'] = 0
annotation_item['iscrowd'] = 1
else:
annotation_item['ignore'] = 0
annotation_item['iscrowd'] = 0
annotation_item['image_id'] = int(image_id)
annotation_item['bbox'] = xywh.astype(int).tolist()
annotation_item['category_id'] = int(category_id)
annotation_item['id'] = int(annotation_id)
coco['annotations'].append(annotation_item)
return annotation_id + 1
for category_id, name in enumerate(voc_classes()):
category_item = dict()
category_item['supercategory'] = str('none')
category_item['id'] = int(category_id)
category_item['name'] = str(name)
coco['categories'].append(category_item)
for ann_dict in annotations:
file_name = ann_dict['filename']
ann = ann_dict['ann']
assert file_name not in image_set
image_item = dict()
image_item['id'] = int(image_id)
image_item['file_name'] = str(file_name)
image_item['height'] = int(ann_dict['height'])
image_item['width'] = int(ann_dict['width'])
coco['images'].append(image_item)
image_set.add(file_name)
bboxes = ann['bboxes'][:, :4]
labels = ann['labels']
for bbox_id in range(len(bboxes)):
bbox = bboxes[bbox_id]
label = labels[bbox_id]
annotation_id = addAnnItem(
annotation_id, image_id, label, bbox, difficult_flag=0)
bboxes_ignore = ann['bboxes_ignore'][:, :4]
labels_ignore = ann['labels_ignore']
for bbox_id in range(len(bboxes_ignore)):
bbox = bboxes_ignore[bbox_id]
label = labels_ignore[bbox_id]
annotation_id = addAnnItem(
annotation_id, image_id, label, bbox, difficult_flag=1)
image_id += 1
return coco
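# Illustrative (hypothetical) shape of the returned COCO-style dict, assuming
# the standard VOC class order (so category_id 11 would be 'dog'):
# {'images': [{'id': 0, 'file_name': '001552.jpg', 'height': 375, 'width': 500}, ...],
#  'type': 'instance',
#  'categories': [{'supercategory': 'none', 'id': 0, 'name': 'aeroplane'}, ...],
#  'annotations': [{'segmentation': [[47, 238, 47, 370, 193, 370, 193, 238]],
#                   'area': 19272, 'ignore': 0, 'iscrowd': 0, 'image_id': 0,
#                   'bbox': [47, 238, 146, 132], 'category_id': 11, 'id': 0}, ...]}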
def parse_args():
parser = argparse.ArgumentParser(
description='Convert PASCAL VOC annotations to mmdetection format')
parser.add_argument('devkit_path', help='pascal voc devkit path')
parser.add_argument('-o', '--out-dir', help='output path')
parser.add_argument(
'--out-format',
default='coco',
choices=('pkl', 'coco'),
help='output format, "coco" indicates coco annotation format')
args = parser.parse_args()
return args
def main():
    args = parse_args()  # parse the command-line arguments
    devkit_path = args.devkit_path  # path of the (VOC-like) devkit
    out_dir = args.out_dir if args.out_dir else devkit_path  # default to the devkit path if no output dir is given
    mmcv.mkdir_or_exist(out_dir)  # create the output dir if it does not exist
years = []
if osp.isdir(osp.join(devkit_path, 'VOC2007')):
years.append('2007')
if osp.isdir(osp.join(devkit_path, 'VOC2012')):
years.append('2012')
if '2007' in years and '2012' in years:
        years.append(['2007', '2012'])  # also build a combined 07+12 annotation set when both years are present
if not years:
        # the given path contains no VOC2007 or VOC2012 dataset folder
        raise IOError(f'The devkit path {devkit_path} contains neither '
                      '"VOC2007" nor "VOC2012" subfolder')
out_fmt = f'.{args.out_format}'
if args.out_format == 'coco':
        out_fmt = '.json'  # pick the file extension matching the requested output format
for year in years:
if year == '2007':
prefix = 'voc07'
elif year == '2012':
prefix = 'voc12'
elif year == ['2007', '2012']:
prefix = 'voc0712'
        for split in ['train', 'val', 'trainval']:  # handle each split of this year separately
dataset_name = prefix + '_' + split
print(f'processing {dataset_name} ...')
            cvt_annotations(devkit_path, year, split,
                            osp.join(out_dir, dataset_name + out_fmt))  # convert this split via cvt_annotations
if not isinstance(year, list):
dataset_name = prefix + '_test'
print(f'processing {dataset_name} ...')
cvt_annotations(devkit_path, year, 'test',
osp.join(out_dir, dataset_name + out_fmt))
print('Done!')
if __name__ == '__main__':
main()
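After the conversion finishes, the generated annotation file can be sanity-checked with the COCO API mentioned at the beginning. A minimal sketch, assuming the output file is named voc12_train.json and pycocotools is installed:

from pycocotools.coco import COCO

coco = COCO('voc12_train.json')          # load the generated annotation file
print(len(coco.getImgIds()))             # number of converted images
print(coco.loadCats(coco.getCatIds()))   # the converted category list
ann_ids = coco.getAnnIds(imgIds=coco.getImgIds()[:1])
print(coco.loadAnns(ann_ids))            # annotations of the first image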
mmcv.mkdir_or_exist (excerpt from mmcv): create the output folder if it does not already exist
def mkdir_or_exist(dir_name, mode=0o777):
if dir_name == '':
return
dir_name = osp.expanduser(dir_name)
os.makedirs(dir_name, mode=mode, exist_ok=True)
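A minimal usage sketch (the directory name is hypothetical):

import mmcv
mmcv.mkdir_or_exist('/path/to/output')  # succeeds silently if the folder already exists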
mmcv.list_from_file (excerpt from mmcv): read a text file and return its lines as a list of strings
def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
"""Load a text file and parse the content as a list of strings.
Args:
filename (str): Filename.
        prefix (str): The prefix to be inserted to the beginning of each item.
offset (int): The offset of lines.
max_num (int): The maximum number of lines to be read,
zeros and negatives mean no limitation.
encoding (str): Encoding used to open the file. Default utf-8.
Returns:
list[str]: A list of strings.
"""
cnt = 0
item_list = []
with open(filename, 'r', encoding=encoding) as f:
for _ in range(offset):
f.readline()
for line in f:
if 0 < max_num <= cnt:
break
item_list.append(prefix + line.rstrip('\n\r'))
cnt += 1
return item_list
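A minimal usage sketch, assuming a split file that contains one image id per line (the path is hypothetical):

import mmcv
img_names = mmcv.list_from_file('/home/lym/darknet/VOC2012/ImageSets/Main/train.txt')
# e.g. img_names == ['001552', '000350', '002229', ...]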
mmcv.track_progress (excerpt from mmcv): show a progress bar while feeding each task to the given function and collecting the results
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
"""Track the progress of tasks execution with a progress bar.
Tasks are done with a simple for-loop.
Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
bar_width (int): Width of progress bar.
Returns:
list: The task results.
"""
if isinstance(tasks, tuple):
assert len(tasks) == 2
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
    elif isinstance(tasks, Iterable):  # the branch taken here, since tasks is passed in as a plain list
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
prog_bar = ProgressBar(task_num, bar_width, file=file)
results = []
for task in tasks:
        results.append(func(task, **kwargs))  # feed each task to func and collect its result
prog_bar.update()
prog_bar.file.write('\n')
return results
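A minimal usage sketch with a trivial task function:

import mmcv

def square(x):
    return x * x

results = mmcv.track_progress(square, [1, 2, 3, 4])  # prints a progress bar
# results == [1, 4, 9, 16]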
mmcv.dump (excerpt from mmcv): write the final annotations to the requested json file
def dump(obj, file=None, file_format=None, **kwargs):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
and also supports custom arguments for each file format.
Args:
obj (any): The python object to be dumped.
file (str or :obj:`Path` or file-like object, optional): If not
specified, then the object is dump to a str, otherwise to a file
specified by the filename or file-like object.
file_format (str, optional): Same as :func:`load`.
Returns:
bool: True for success, False otherwise.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
            file_format = file.split('.')[-1]  # infer the format from the file extension
elif file is None:
raise ValueError(
'file_format must be specified since file is None')
    if file_format not in file_handlers:
        # file_handlers maps each supported format to its handler:
        # file_handlers = {
        #     'json': JsonHandler(),
        #     'yaml': YamlHandler(),
        #     'yml': YamlHandler(),
        #     'pickle': PickleHandler(),
        #     'pkl': PickleHandler()
        # }
        raise TypeError(f'Unsupported format: {file_format}')
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
handler.dump_to_path(obj, file, **kwargs)
elif hasattr(file, 'write'):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
raise TypeError('"file" must be a filename str or a file-object')
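A minimal usage sketch (the data and filename are hypothetical):

import mmcv
mmcv.dump({'images': [], 'annotations': [], 'categories': []}, 'voc12_train.json')
# the format is inferred from the '.json' extension, so JsonHandler writes the file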