COCO API的安装,COCO数据集介绍以及读取接口代码解读(PyTorch接口)

COCO(Common Objects in Context)数据集是微软发布的大型数据集,可以用来做目标检测,实例分割,语义分割,关键点检测,以及场景描述。在学术界,COCO基本上被分为两个版本,2014版和2017版。2017版是在2014版的基础上做的扩充。数据集分为训练集,验证集和测试集。其中测试集在官网服务器上,Ground Truth未公布。
COCO的标注存放在json文件中。以2017版为例子。
COCO API的安装,COCO数据集介绍以及读取接口代码解读(PyTorch接口)_第1张图片
做目标检测,实例分割,语义分割,采用instances开头的json文件。

COCO提供了API读取数据,但不是直接将数据读入内存,而是读取图像的文件名,目标类别,位置等信息。如要读取图像,还是需要使用PIL或者opencv进行读取。所以,现在的开源论文项目,都是将COCO API再加工,封装为一个适合模型训练和测试的dataset class。
接下来,我会先介绍如何安装COCO API(在Windows和Ubuntu上),然后简单地介绍一下API,最后写一个读取COCO的数据接口(PyTorch接口)。

安装pycocotools

pycocotools是微软提供的导入coco信息的库。如果你是Ubuntu系统,只需要提前安装好Gcc,Cython,按照COCO github主页的方式安装就行了。至于Windows系统,要预先安装Visual C++ build(C++编译器),建议直接安装VS2017,并安装好C++组件。cython同样是必不可少的。
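先把pycocotools从github上clone下来。打开setup.py文件,删除下图所示的内容(原截图缺失;通常指 extra_compile_args 中 MSVC 不支持的 GCC 编译选项,例如 '-Wno-cpp' 和 '-Wno-unused-function',请以实际的 setup.py 为准)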
COCO API的安装,COCO数据集介绍以及读取接口代码解读(PyTorch接口)_第2张图片
然后

python setup.py build_ext install

等几分钟就安装好了,如果控制台没有出现error字样,则安装成功。
可以试着导入一下

import pycocotools

注意,如果你使用的是anaconda环境,pycocotools会作为一个包安装到site-packages里面。如果仅仅是普通的Python环境,在导入之前,需要把cocoapi/PythonAPI添加到模块搜索路径(sys.path)中:

try:
    # Works when pycocotools is already importable from the active
    # environment.
    from pycocotools.coco import COCO
except ModuleNotFoundError:
    import sys

    # Fall back to a local checkout of the COCO API. The raw string keeps the
    # backslashes of the Windows path literal instead of treating "\A", "\c"
    # and "\P" as (invalid) escape sequences, which newer Python versions
    # warn about.
    sys.path.append(r'D:\API\cocoapi\PythonAPI')
    from pycocotools.coco import COCO

pycocotools几个常用API

  • 构建coco对象, coco = pycocotools.coco.COCO(json_file)
  • coco.getImgIds(self, imgIds=[], catIds=[]) 返回满足条件的图像id
  • coco.imgs.keys() 数据集中所有样本的id号
  • coco.imgToAnns.keys() 数据集中有GT对应的图像样本的id号(用来过滤没有标签的样本)
  • coco.getCatIds(catNms=[...]) 根据类别名称(一类或者几类),返回对应的类别id号
  • coco.loadImgs()根据id号,导入对应的图像信息
  • coco.getAnnIds() 根据图像id号,获得该图像对应的GT(标注)的id号
  • coco.loadAnns() 根据 Annotation id号,导入标签信息

基本常用的就是这些了。

# Demo: load the COCO 2017 validation annotations and print a few counts.
from pycocotools.coco import COCO

# Annotation json and image directory of the validation split (val_image is
# not used by this snippet).
val_info = r'E:\dataset\coco\annotations\annotations_trainval2017\annotations\instances_val2017.json'
val_image = r'E:\dataset\coco\images\val2017'

coco = COCO(val_info)  # load the validation set
# Ids of every image listed in the annotation file.
all_ids = coco.imgs.keys()
print(len(all_ids))
# Category id(s) of the class named 'person'.
person_id = coco.getCatIds(catNms=['person'])
print(person_id)
# Ids of the images selected by that category filter.
person_imgs_id = coco.getImgIds(catIds=person_id)
print(len(person_imgs_id))
###
'''
loading annotations into memory...
Done (t=1.45s)
creating index...
index created!
5000  # 验证集样本总数
[1]  # 人这个类的类别id
2693  # 在验证集中,包含人这个类的图像有2693张
'''
###

读取COCO的PyTorch数据接口

写一个COCODetection类,继承自data.Dataset。

from pycocotools.coco import COCO
import os
import os.path as osp
import sys
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import cv2
import numpy as np
from pycocotools.coco import COCO

val_info = r'E:\dataset\coco\annotations\annotations_trainval2017\annotations\instances_val2017.json'
val_image = r'E:\dataset\coco\images\val2017'
# Human-readable names of the 80 COCO object classes.
# NOTE(review): presumably ordered so that index i matches the 0-based label
# produced by COCOAnnotationTransform (COCO_LABEL_MAP[...] - 1) - confirm.
# Every entry must be a plain string: a stray "*" in front of an entry would
# unpack that string into single characters and shift all later indices.
COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
                'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
                'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
                'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
                'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
                'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
                'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
                'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
                'scissors', 'teddy bear', 'hair drier', 'toothbrush')

# Maps the category ids used in the COCO annotation files (1..90 with gaps)
# to contiguous 1-based labels (1..80), in ascending id order.
COCO_LABEL_MAP = {
    category_id: label
    for label, category_id in enumerate(
        (
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
            11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
            22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
            35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
            46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
            56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
            67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
            80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
        ),
        start=1,
    )
}
class COCOAnnotationTransform(object):
    """Turn COCO annotation dicts into normalized corner-format box rows.

    Calling an instance with ``(target, width, height)`` returns a list of
    ``[xmin, ymin, xmax, ymax, label_idx]`` rows, where the coordinates are
    divided by the image width/height and ``label_idx`` is the mapped
    category label minus one.
    """

    def __init__(self):
        # Category id -> contiguous 1-based label.
        self.label_map = COCO_LABEL_MAP

    def __call__(self, target, width, height):
        """Convert every annotation in ``target`` that carries a ``'bbox'``.

        Args:
            target: Iterable of annotation dicts with ``'bbox'`` given as
                ``[x, y, w, h]`` and a ``'category_id'`` key.
            width: Image width used to normalize x coordinates.
            height: Image height used to normalize y coordinates.

        Returns:
            A list of ``[xmin, ymin, xmax, ymax, label_idx]`` rows.
            Annotations without a ``'bbox'`` key are reported with ``print``
            and left out.
        """
        normalizer = np.array([width, height, width, height])
        rows = []
        for annotation in target:
            if 'bbox' not in annotation:
                print("No bbox found for object ", annotation)
                continue

            box = annotation['bbox']
            label_idx = self.label_map[annotation['category_id']] - 1
            # [x, y, w, h] -> [xmin, ymin, xmax, ymax]
            corners = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]])
            row = list(corners / normalizer)
            row.append(label_idx)
            rows.append(row)

        return rows


class COCODetection(data.Dataset):
    """Dataset over a COCO annotation file yielding images, targets and masks.

    Each sample is ``(image, (gt, masks, num_crowds))``; see ``pull_item``
    for the individual pieces.

    Args:
        image_path: Directory that contains the image files.
        info_file: Path of the COCO annotation json handed to ``COCO``.
        transform: Stored on the instance; this class does not apply it.
        target_transform: Optional callable ``(target, width, height)``
            applied to the annotation list when it is non-empty.
        has_gt: When false, annotations are not loaded and every image listed
            in the annotation file is used.
    """

    def __init__(self, image_path, info_file, transform=None,
                 target_transform=None, has_gt=True):
        self.root = image_path
        self.coco = COCO(info_file)
        # Only images that have annotations; this can be fewer than the total
        # number of images in the annotation file.
        self.ids = list(self.coco.imgToAnns.keys())

        # Without any annotations, or when ground truth is not wanted, fall
        # back to every image in the file.
        if len(self.ids) == 0 or not has_gt:
            self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

        self.has_gt = has_gt

    def __len__(self):
        """Return the number of images served by this dataset."""
        return len(self.ids)

    def __getitem__(self, index):
        """Return ``(image, (gt, masks, num_crowds))`` for sample ``index``."""
        im, gt, masks, h, w, num_crowds = self.pull_item(index)
        return im, (gt, masks, num_crowds)

    def _stack_masks(self, target, height, width):
        """Return one mask per annotation as an ``(N, height, width)`` array.

        An empty ``target`` gives an array of shape ``(0, height, width)`` so
        callers always receive an array.
        """
        if not target:
            # NOTE(review): uint8 is assumed to match the dtype produced by
            # annToMask - confirm.
            return np.zeros((0, height, width), dtype=np.uint8)
        flat_masks = [self.coco.annToMask(obj).reshape(-1) for obj in target]
        return np.vstack(flat_masks).reshape(-1, height, width)

    def pull_item(self, index):
        """Load image, annotations and masks for the sample at ``index``.

        Returns:
            Tuple ``(image, target, masks, height, width, num_crowds)``:
            ``image`` is the ``cv2.imread`` result as a channel-first tensor;
            ``target`` is the annotation list (crowd annotations last),
            passed through ``target_transform`` when one is set and the list
            is non-empty; ``masks`` has one ``(height, width)`` mask per
            annotation; ``num_crowds`` counts the crowd annotations.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_id = self.ids[index]
        target = []
        if self.has_gt:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            target = self.coco.loadAnns(ann_ids)

        # Split crowd from regular annotations, then put the crowd ones last
        # so that they always sit at the end of the array.
        regular, crowd = [], []
        for ann in target:
            (crowd if ann.get('iscrowd') else regular).append(ann)
        num_crowds = len(crowd)
        target = regular + crowd

        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        path = osp.join(self.root, file_name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure with None rather than raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        height, width, _ = img.shape

        # Always defined, even for images without annotations (for example
        # whenever has_gt is False).
        masks = self._stack_masks(target, height, width)
        if self.target_transform is not None and len(target) > 0:
            target = self.target_transform(target, width, height)
        return torch.from_numpy(img).permute(2, 0, 1), target, masks, height, width, num_crowds


from torch.utils.data import DataLoader
import numpy as np
if __name__ == '__main__':
    # Visual check: overlay every instance mask on its image and show each
    # result for half a second. The DataLoader uses its default settings.
    dataset = COCODetection(val_image, val_info)
    loader = DataLoader(dataset)
    for image_tensor, annotations in loader:
        # Back to a channel-last uint8 array for OpenCV.
        canvas = np.uint8(image_tensor.squeeze().numpy().transpose(1, 2, 0))
        _gt, instance_masks, _num_crowds = annotations
        instance_masks = instance_masks.squeeze(0)
        for instance in instance_masks:
            mask = instance.numpy()
            # Paint the mask with a random intensity in one random channel.
            color = np.random.randint(0, 255)
            channel = np.random.randint(0, 3)
            rows, cols = np.nonzero(mask == 1)
            canvas[rows, cols, channel] = color
        cv2.imshow('img', canvas)
        cv2.waitKey(500)

运行代码:
COCO API的安装,COCO数据集介绍以及读取接口代码解读(PyTorch接口)_第3张图片
COCO API的安装,COCO数据集介绍以及读取接口代码解读(PyTorch接口)_第4张图片

COCO API的安装,COCO数据集介绍以及读取接口代码解读(PyTorch接口)_第5张图片

你可能感兴趣的:(Pytorch)