[In-Depth Dissection] Implementing RetinaNet from Scratch (Part 1): COCO and VOC Dataset Processing

Table of Contents

  • Preface
  • Introduction to the COCO dataset
  • Introduction to the VOC dataset
  • File organization for the COCO and VOC datasets
  • Processing the COCO dataset
  • Processing the VOC dataset

All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and verified to work correctly.

Preface

After the base-model series of ImageNet training experiments, I am finally moving on to object detection. Object detection involves a great many details that papers usually leave unstated (they are typically inherited from earlier detectors), so the only way to really understand them is through the code. The best way to learn is to implement a detector yourself. In this series I will implement the one-stage detector RetinaNet from scratch, covering dataset processing, data augmentation, the network architecture, the loss, decoding, and more.

Introduction to the COCO dataset

The official COCO website is http://cocodataset.org/#home . COCO is a large-scale object detection dataset. It is updated every year, but object detection papers only use COCO2014 and COCO2017. COCO2017 consists of three subsets: train (118287 images), val (5000 images), and test (40670 images), with 80 classes in total. Ground truth is provided for the train and val sets; the test set has no ground truth, so you must submit your detection results to the COCO website to obtain test scores.

What is the difference between COCO2014 and COCO2017?
The Detectron code released for the RetinaNet paper (https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md) explains this. All models in the RetinaNet paper are trained on the union of coco_2014_train (82783 images) and coco_2014_valminusminival, a randomly split 35504-image subset of the 40504-image coco_2014_val set; this union is exactly identical to coco_2017_train. For testing, all models are evaluated on coco_2014_minival, the other subset containing the remaining 5000 images, which is exactly identical to coco_2017_val. In other words, the RetinaNet paper effectively trains on coco_2017_train and evaluates on coco_2017_val.
In the RetinaNet paper, reported performance is the mAP at IoU=0.5:0.95, with at most 100 detections kept per image and objects of all sizes included (i.e. the stats[0] value computed by the _summarizeDets function of the COCOeval class in pycocotools.cocoeval).

How much does performance differ between the val set and the test set?
Since the train, val, and test sets are all random splits of the same parent dataset, the gap between val and test performance is very small. Judging from papers that report results on both, the mAP on val and test usually differs by only about 0.2 to 0.3 percentage points.

In the reproduction that follows, we adopt the dataset setup of the RetinaNet paper: train on coco_2017_train and evaluate on coco_2017_val. As the performance metric we use the mAP at IoU=0.5:0.95, with at most 100 detections kept per image and objects of all sizes included (i.e. stats[0] from the _summarizeDets function of the COCOeval class in pycocotools.cocoeval).
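
For reference, this metric can be computed with pycocotools roughly as follows once the model's detections have been dumped to a results file in COCO format. This is only a minimal sketch; the file paths are placeholders, not part of the repository code.

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# load ground truth and detection results (paths are placeholders)
coco_gt = COCO('COCO2017/annotations/instances_val2017.json')
coco_dt = coco_gt.loadRes('detections_val2017.json')

coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()

# stats[0] is the mAP at IoU=0.5:0.95, all object sizes, at most 100 detections per image
print(coco_eval.stats[0])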

Introduction to the VOC dataset

The official VOC website is http://host.robots.ox.ac.uk/pascal/VOC/ . VOC is also an object detection dataset, but it is much smaller than COCO. Object detection papers usually use VOC2007 and VOC2012. Like COCO, both VOC2007 and VOC2012 are split into train, val, and test subsets, with 20 classes in total. For VOC2007, ground truth is provided for train, val, and test; for VOC2012, ground truth is only provided for train and val.

Following the Faster R-CNN training and testing setup on VOC used in detectron2 (https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md), we train on VOC2007 trainval + VOC2012 trainval and test on VOC2007 test. At test time, mAP is computed with the VOC2007 11-point metric.
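
As a reminder of what the 11-point metric does: for each class, precision is sampled at the 11 recall thresholds 0, 0.1, ..., 1.0, and the AP is the average of the maximum precision achieved at recall greater than or equal to each threshold. A minimal sketch (not the exact evaluation code used later in this series) looks like this:

import numpy as np


def voc2007_11_point_ap(recall, precision):
    # recall and precision are 1D arrays for one class,
    # accumulated over detections sorted by descending confidence
    ap = 0.
    for t in np.arange(0., 1.1, 0.1):
        mask = recall >= t
        p = np.max(precision[mask]) if mask.any() else 0.
        ap += p / 11.
    return ap


# example: a perfect precision-recall curve gives AP = 1.0
print(voc2007_11_point_ap(np.array([0.5, 1.0]), np.array([1.0, 1.0])))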

File organization for the COCO and VOC datasets

After downloading the COCO and VOC datasets, arrange the folders as follows:

COCO2017
|
|
|----annotations----contains all annotation json files
|
|                  |----train2017
|----images--------|----val2017
                   |----test2017

VOCdataset
|
|
|                  |----Annotations
|                  |----ImageSets
|----VOC2007-------|----JPEGImages
|                  |----SegmentationClass
|                  |----SegmentationObject
|
|                  |----Annotations
|                  |----ImageSets
|----VOC2012-------|----JPEGImages
|                  |----SegmentationClass
|                  |----SegmentationObject

Processing the COCO dataset

The raw box coordinates in the COCO2017 annotations are [x_min,y_min,w,h], i.e. the top-left corner of the box plus its width and height. We convert them to [x_min,y_min,x_max,y_max], i.e. the top-left and bottom-right corners. The annotations also provide a category index for each box, but the original indices are not contiguous (they range from 1 to 90 even though there are only 80 classes), so we remap them to contiguous indices 0-79.
The code for processing the COCO dataset is as follows:

import os
import cv2
import torch
import numpy as np
import random
from torch.utils.data import Dataset
from pycocotools.coco import COCO
import torch.nn.functional as F

COCO_CLASSES = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]

colors = [
    (39, 129, 113),
    (164, 80, 133),
    (83, 122, 114),
    (99, 81, 172),
    (95, 56, 104),
    (37, 84, 86),
    (14, 89, 122),
    (80, 7, 65),
    (10, 102, 25),
    (90, 185, 109),
    (106, 110, 132),
    (169, 158, 85),
    (188, 185, 26),
    (103, 1, 17),
    (82, 144, 81),
    (92, 7, 184),
    (49, 81, 155),
    (179, 177, 69),
    (93, 187, 158),
    (13, 39, 73),
    (12, 50, 60),
    (16, 179, 33),
    (112, 69, 165),
    (15, 139, 63),
    (33, 191, 159),
    (182, 173, 32),
    (34, 113, 133),
    (90, 135, 34),
    (53, 34, 86),
    (141, 35, 190),
    (6, 171, 8),
    (118, 76, 112),
    (89, 60, 55),
    (15, 54, 88),
    (112, 75, 181),
    (42, 147, 38),
    (138, 52, 63),
    (128, 65, 149),
    (106, 103, 24),
    (168, 33, 45),
    (28, 136, 135),
    (86, 91, 108),
    (52, 11, 76),
    (142, 6, 189),
    (57, 81, 168),
    (55, 19, 148),
    (182, 101, 89),
    (44, 65, 179),
    (1, 33, 26),
    (122, 164, 26),
    (70, 63, 134),
    (137, 106, 82),
    (120, 118, 52),
    (129, 74, 42),
    (182, 147, 112),
    (22, 157, 50),
    (56, 50, 20),
    (2, 22, 177),
    (156, 100, 106),
    (21, 35, 42),
    (13, 8, 121),
    (142, 92, 28),
    (45, 118, 33),
    (105, 118, 30),
    (7, 185, 124),
    (46, 34, 146),
    (105, 184, 169),
    (22, 18, 5),
    (147, 71, 73),
    (181, 64, 91),
    (31, 39, 184),
    (164, 179, 33),
    (96, 50, 18),
    (95, 15, 106),
    (113, 68, 54),
    (136, 116, 112),
    (119, 139, 130),
    (31, 139, 34),
    (66, 6, 127),
    (62, 39, 2),
    (49, 99, 180),
    (49, 119, 155),
    (153, 50, 183),
    (125, 38, 3),
    (129, 87, 143),
    (49, 87, 40),
    (128, 62, 120),
    (73, 85, 148),
    (28, 144, 118),
    (29, 9, 24),
    (175, 45, 108),
    (81, 175, 64),
    (178, 19, 157),
    (74, 188, 190),
    (18, 114, 2),
    (62, 128, 96),
    (21, 3, 150),
    (0, 6, 95),
    (2, 20, 184),
    (122, 37, 185),
]


class CocoDetection(Dataset):
    def __init__(self,
                 image_root_dir,
                 annotation_root_dir,
                 set='train2017',
                 transform=None):
        self.image_root_dir = image_root_dir
        self.annotation_root_dir = annotation_root_dir
        self.set_name = set
        self.transform = transform

        self.coco = COCO(
            os.path.join(self.annotation_root_dir,
                         'instances_' + self.set_name + '.json'))

        self.load_classes()

    def load_classes(self):
        self.image_ids = self.coco.getImgIds()
        self.cat_ids = self.coco.getCatIds()
        self.categories = self.coco.loadCats(self.cat_ids)
        self.categories.sort(key=lambda x: x['id'])

        # category_id is the original COCO id; coco_label is a contiguous index from 0 to 79
        self.category_id_to_coco_label = {
            category['id']: i
            for i, category in enumerate(self.categories)
        }
        self.coco_label_to_category_id = {
            v: k
            for k, v in self.category_id_to_coco_label.items()
        }

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img = self.load_image(idx)
        annot = self.load_annotations(idx)

        sample = {'img': img, 'annot': annot, 'scale': 1.}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def load_image(self, image_index):
        image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
        path = os.path.join(self.image_root_dir, image_info['file_name'])
        # opencv reads images in BGR order; convert to RGB
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32) / 255.

    def load_annotations(self, image_index):
        # get ground truth annotations
        annotations_ids = self.coco.getAnnIds(
            imgIds=self.image_ids[image_index], iscrowd=None)
        annotations = np.zeros((0, 5))

        # some images appear to miss annotations
        if len(annotations_ids) == 0:
            return annotations

        # parse annotations
        coco_annotations = self.coco.loadAnns(annotations_ids)
        for _, a in enumerate(coco_annotations):
            # some annotations have basically no width / height, skip them
            if a['bbox'][2] < 1 or a['bbox'][3] < 1:
                continue

            annotation = np.zeros((1, 5))
            annotation[0, :4] = a['bbox']
            annotation[0, 4] = self.find_coco_label_from_category_id(
                a['category_id'])

            annotations = np.append(annotations, annotation, axis=0)

        # transform from [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max]
        annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
        annotations[:, 3] = annotations[:, 1] + annotations[:, 3]

        return annotations

    def find_coco_label_from_category_id(self, category_id):
        return self.category_id_to_coco_label[category_id]

    def find_category_id_from_coco_label(self, coco_label):
        return self.coco_label_to_category_id[coco_label]

    def num_classes(self):
        return 80

    def image_aspect_ratio(self, image_index):
        image = self.coco.loadImgs(self.image_ids[image_index])[0]
        return float(image['width']) / float(image['height'])

Each item yielded by this class is a dict holding the information for one image: the value under the key 'img' is the image itself, and the numpy array under the key 'annot' holds the objects annotated in that image. Note that the number of annotated objects varies from image to image, and some images may have no annotated objects at all.
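
As a quick sanity check, the class can be used like this; the paths below simply follow the directory layout shown earlier and are not hard-coded anywhere:

coco_train_dataset = CocoDetection(
    image_root_dir='COCO2017/images/train2017',
    annotation_root_dir='COCO2017/annotations',
    set='train2017',
    transform=None)

sample = coco_train_dataset[0]
# sample['img'] is an (H, W, 3) float32 RGB image with values in [0, 1]
# sample['annot'] is an (N, 5) array: [x_min, y_min, x_max, y_max, class index 0-79]
print(sample['img'].shape, sample['annot'].shape)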

Processing the VOC dataset

The raw box coordinates in the VOC annotations are already [x_min,y_min,x_max,y_max], so no coordinate conversion is needed. The annotations only provide the class name of each object, so we map each name to a class index 0-19.
The code for processing the VOC dataset is as follows:

import os
import cv2
import numpy as np
import random
import xml.etree.ElementTree as ET

import torch
from torch.utils.data import Dataset

VOC_CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]

colors = [
    (39, 129, 113),
    (164, 80, 133),
    (83, 122, 114),
    (99, 81, 172),
    (95, 56, 104),
    (37, 84, 86),
    (14, 89, 122),
    (80, 7, 65),
    (10, 102, 25),
    (90, 185, 109),
    (106, 110, 132),
    (169, 158, 85),
    (188, 185, 26),
    (103, 1, 17),
    (82, 144, 81),
    (92, 7, 184),
    (49, 81, 155),
    (179, 177, 69),
    (93, 187, 158),
    (13, 39, 73),
]


class VocDetection(Dataset):
    def __init__(self,
                 root_dir,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 transform=None,
                 keep_difficult=False):
        self.root_dir = root_dir
        self.image_set = image_sets
        self.transform = transform
        self.categories = VOC_CLASSES

        self.category_id_to_voc_label = dict(
            zip(self.categories, range(len(self.categories))))
        self.voc_label_to_category_id = {
            v: k
            for k, v in self.category_id_to_voc_label.items()
        }

        self.keep_difficult = keep_difficult

        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = os.path.join(self.root_dir, 'VOC' + year)
            for line in open(
                    os.path.join(rootpath, 'ImageSets', 'Main',
                                 name + '.txt')):
                self.ids.append((rootpath, line.strip()))

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img = self.load_image(img_id)

        target = ET.parse(self._annopath % img_id).getroot()
        annot = self.load_annotations(target)

        sample = {'img': img, 'annot': annot, 'scale': 1.}

        if self.transform:
            sample = self.transform(sample)
        return sample

    def load_image(self, img_id):
        # opencv reads images in BGR order; convert to RGB
        img = cv2.imread(self._imgpath % img_id)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32) / 255.

    def load_annotations(self, target):
        annotations = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']

            bndbox = []
            for pt in pts:
                cur_pt = float(bbox.find(pt).text)
                bndbox.append(cur_pt)
            label_idx = self.category_id_to_voc_label[name]
            bndbox.append(label_idx)
            annotations += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]

        annotations = np.array(annotations)
        # format:[[x1, y1, x2, y2, label_ind], ... ]
        return annotations

    def find_category_id_from_voc_label(self, voc_label):
        return self.voc_label_to_category_id[voc_label]

    def image_aspect_ratio(self, idx):
        img_id = self.ids[idx]
        image = self.load_image(img_id)
        # aspect ratio = width / height
        return float(image.shape[1]) / float(image.shape[0])

    def __len__(self):
        return len(self.ids)

As with the COCO class, each item yielded by this class is a dict holding the information for one image: the value under the key 'img' is the image itself, and the numpy array under the key 'annot' holds the objects annotated in that image. Again, the number of annotated objects varies from image to image, and some images may have no annotated objects at all.
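
A corresponding usage sketch; again, the root_dir value simply follows the directory layout shown earlier:

voc_train_dataset = VocDetection(
    root_dir='VOCdataset',
    image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
    transform=None,
    keep_difficult=False)

sample = voc_train_dataset[0]
# sample['img'] is an (H, W, 3) float32 RGB image with values in [0, 1]
# sample['annot'] is an (N, 5) array: [x_min, y_min, x_max, y_max, class index 0-19]
print(sample['img'].shape, sample['annot'].shape)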
