VOC数据集到RRPN所需数据格式转换脚本

1. 前言

在之前的文章中已经介绍了过了RRPN的原理,在给出的代码里面也写了相关数据转换的脚本,其实只需要理解了其内部所需数据的类型就可以按照自己的意愿将现有的数据集转换过去,而且该算法可以做到多类别检测,只是需要很小的修改就可以实现。

2. 转换脚本

def get_PascalVOC_2007():
    DATASET_DIR = 'xxx/data/VOCdevkit2007'  # VOC数据集路径

    img_dir = os.path.join(DATASET_DIR, 'VOC2007', 'JPEGImages')
    gt_dir = os.path.join(DATASET_DIR, 'VOC2007', 'Annotations')
    img_set = os.path.join(DATASET_DIR, 'VOC2007', 'ImageSets', 'Main', 'trainval.txt')


    num_classes = 5  # 分类的数目
    box_classes = ('__background__', 'class1', 'class2', 'class3', 'class4')
    box_classes_to_ind = dict(zip(box_classes, np.arange(num_classes)))

    im_infos = []  # 标注信息

    with open(img_set, 'r') as f:
        gt_file_list = f.readlines()
    print("buddy: all labeled image size: {}".format(len(gt_file_list)))

    # gt_file_list = gt_file_list[0:100]

    for index, gt_file in enumerate(gt_file_list):
        if index % 10000 == 0:
            print('buddy: load process-{}/{}'.format(index, len(gt_file_list)))
        gt_file = gt_file.strip('\n')
        gt_fobj = open(os.path.join(gt_dir, gt_file + '.xml'))
        tree = ET.parse(gt_fobj)
        objs = tree.findall('object')  # 找到所有的标注框节点

        img_name = os.path.join(img_dir, gt_file + '.jpg')
        img = cv2.imread(img_name)

        len_of_bboxes = len(objs)
        gt_boxes = np.zeros((len_of_bboxes, 5), dtype=np.int16)
        gt_classes = np.zeros((len_of_bboxes), dtype=np.int32)
        overlaps = np.zeros((len_of_bboxes, num_classes), dtype=np.float32)  # text or non-text
        seg_areas = np.zeros((len_of_bboxes), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for idx, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = box_classes_to_ind[obj.find('name').text.lower().strip()]
            width = x2 - x1
            height = y2 - y1
            x_ctr = x1 + width/2
            y_ctr = y1 + height/2
            gt_boxes[idx, :] = [x_ctr, y_ctr, height, width, 0]
            gt_classes[idx] = cls
            overlaps[idx, cls] = 1.0
            seg_areas[idx] = (x2 - x1 + 1) * (y2 - y1 + 1)

        max_overlaps = overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = overlaps.argmax(axis=1)

        im_info = {
            'gt_classes': gt_classes,
            'max_classes': max_classes,
            'image': os.path.join(img_dir, gt_file + '.jpg'),
            'boxes': gt_boxes,
            'flipped': False,
            'gt_overlaps': overlaps,
            'seg_areas': seg_areas,
            'height': img.shape[0],
            'width': img.shape[1],
            'max_overlaps': max_overlaps,
            'rotated': True  # 使用旋转增广数据
        }
        im_infos.append(im_info)
    return im_infos

你可能感兴趣的:([3],Python相关)