(二) COCO Python API - 源码分析篇

如果对于 COCO 给出的标注信息只是做一些常规场景的使用, 参考 https://github.com/dengdan/coco/blob/master/PythonAPI/pycocoDemo.ipynb 脚本已经完全够用了.

脚本的具体使用方法可以参考我的博客: (一) COCO Python API - 使用篇.

但是对于那些想要进一步挖掘 COCO 数据集的小伙伴来说, 这是远远不够的, 这时就需要直接去源码中找答案了.

另外, YOLO 官方给出的模型是基于 COCO 数据集训练得到的, 但是 YOLO 官网给出的训练方法中, 对于将 COCO 数据集标注格式转换为 YOLO 所需的标注格式并没有给出脚本, 而是直接给出现成的训练数据文件: train.txt 和 标注框数据文件: labels/xxx.txt. 如果用户要在 COCO 数据集的基础上定制自己的数据集, 需要研究源码了.

其实关于 COCO Python API, 核心的文件是在: https://github.com/dengdan/coco/tree/master/PythonAPI/pycocotools/ 目录下, 分别有 coco.py, cocoeval.py 和 mask.py 这三个文件.


1. coco.py 脚本


以下摘自脚本中对于整个脚本文件的一段描述:

# Microsoft COCO Toolbox.      version 2.0
# Data, paper, and tutorials available at:  http://mscoco.org/
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.

这个文件实现了访问 COCO 数据集的接口.可以进行 COCO 标注信息的加载, 解析和可视化操作.  
  
当然, 这个脚本只是提供了常规使用 COCO 数据集的方法, 特殊的使用方法用户可以自行定义.   

coco.py 中定义了以下 API:

  • COCO 类 - 用于加载 COCO 标注文件并准备所需的数据结构.
  • decodeMask - Decode binary mask M encoded via run-length encoding.
  • encodeMask - Encode binary mask M using run-length encoding.
  • getAnnIds - 获取满足条件的标注信息的 ids.
  • getCatIds - 获取满足条件的类别的 ids.
  • getImgIds - 获取满足条件的图片的 ids.
  • loadAnns - 加载指定 ids 对应的标注信息.
  • loadCats - 加载指定 ids 对应的类别.
  • loadImgs - 加载指定 ids 对应的图片.
  • annToMask - 将 segmentation 标注信息转换为二值 mask.
  • showAnns - 显示指定 ids 的标注信息到对应的图片上.
  • loadRes - 加载算法的结果并创建可用于访问数据的 API .
  • download - 从 mscoco.org 服务器下载 COCO 图片数据集.

理解 API 功能的关键是理解这里边包含的几种 ID 信息:图片 ID、类别 ID 和标注信息 ID。

在给定的 API 中, “ann”=annotation, “cat”=category, and “img”=image.接下来就上述提到的 API 一一分析.

1) COCO 类的构造函数

COCO 类的构造函数负责加载 json 文件, 并将其中的标注信息建立关联关系. 关联关系存在于以下两个变量中:

  • imgToAnns: 将图片 id 和标注信息 关联;
  • catToImgs: 将类别 id 和图片 id 关联.

这样做的好处就是: 对于给定的图片 id, 就可以快速找到其所有的标注信息; 对于给定的类别 id, 就可以快速找到属于该类的所有图片.

class COCO:
    def __init__(self, annotation_file=None):
        """
        构造函数, 用来读取关联标注信息.
        :param annotation_file (str): 标注文件路径名
        :return:
        """
        # 加载数据集

        # 定义数据成员
        self.dataset, self.anns, self.cats, self.imgs = dict(),dict(),dict(),dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)

        # 加载标注信息
        if not annotation_file == None:
            print('loading annotations into memory...')
            tic = time.time()
            dataset = json.load(open(annotation_file, 'r'))
            # dataset 的类型为 dict
            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time()- tic))
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        """
        创建索引, 属于构造函数的一部分. 
        """

        # create index
        print('creating index...')

        anns, cats, imgs = {}, {}, {}
        # imgToAnns: 将图片 id 和标注信息 关联; 给定图片 id, 就可以找到其所有的标注信息
        # catToImgs: 将类别 id 和图片 id 关联; 给定类别 id, 就可以找到属于这类的所有图片
        imgToAnns,catToImgs = defaultdict(list),defaultdict(list)

        if 'annotations' in self.dataset:
            for ann in self.dataset['annotations']:
                imgToAnns[ann['image_id']].append(ann)
                anns[ann['id']] = ann

        if 'images' in self.dataset:
            for img in self.dataset['images']:
                imgs[img['id']] = img

        if 'categories' in self.dataset:
            for cat in self.dataset['categories']:
                cats[cat['id']] = cat

        if 'annotations' in self.dataset and 'categories' in self.dataset:
            for ann in self.dataset['annotations']:
                catToImgs[ann['category_id']].append(ann['image_id'])

        print('index created!!!')

        # 给几个类成员赋值
        self.anns = anns
        self.imgs = imgs
        self.cats = cats
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs

理解数据关联关系的重点是对 imgToAnns 变量类型的理解. 其中, imgToAnns 的类型为: defaultdict(). dict 中的键为 id, 值为一个 list, 包含所有的属于这张图的标注信息. 其中 list 的成员类型为字典, 其元素是 “annotations” 域内的一条(或多条)标注信息.

例如:

{ 289343: [
            {'id': 1768, 'area': 702.1057499999998, 
             'segmentation': [[510.66, ... 510.45, 423.01]], 
             'bbox': [473.07, 395.93, 38.65, 28.67], 
             'category_id': 18, 
             'iscrowd': 0, 
             'image_id': 289343
            },
            ...
        ]
}

catToImgs 变量类似于 imgToAnns.

2) getAnnIds() - 获取标注信息的 ids

 def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
   """
   获取满足给定条件的标注信息 id. 如果未指定条件, 则返回整个数据集上的标注信息 id
   :param imgIds  (int 数组)     : 通过图片 ids 指定
           catIds  (int 数组)     : 通过类别 ids 指定
           areaRng (float 数组)   : 通过面积大小的范围, 2 个元素. (e.g. [0 inf]) 指定
           iscrowd (boolean)      : get anns for given crowd label (False or True)
   :return: ids (int 数组)       : 满足条件的标注信息的 ids
   """
   imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
   catIds = catIds if _isArrayLike(catIds) else [catIds]

   if len(imgIds) == len(catIds) == len(areaRng) == 0:
       anns = self.dataset['annotations']
   else:
       if not len(imgIds) == 0:
           lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
           anns = list(itertools.chain.from_iterable(lists))
       else:
           anns = self.dataset['annotations']

       anns = anns if len(catIds)  == 0 
		       else [ann for ann in anns if ann['category_id'] in catIds]
       anns = anns if len(areaRng) == 0 
		       else [ann for ann in anns 
			       if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]

   if not iscrowd == None:
       ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
   else:
       ids = [ann['id'] for ann in anns]

   return ids

3) showAnns() - 显示给定的标注信息

def showAnns(self, anns):
    """
    显示给定的标注信息. 一般在这个函数调用之前, 已经调用过图像显示函数: plt.imshow() 
    :param anns (object 数组): 要显示的标注信息
    :return: None
    """
    if len(anns) == 0:
        return 0
    if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
        datasetType = 'instances'
    elif 'caption' in anns[0]:
        datasetType = 'captions'
    else:
        raise Exception('datasetType not supported')
        
    # 目标检测的标注数据
    if datasetType == 'instances':
        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            # 为每个标注 mask 生成一个随机颜色
            c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
            # 边界 mask 标注信息
            if 'segmentation' in ann:
                if type(ann['segmentation']) == list:
                    # polygon
                    for seg in ann['segmentation']:
                        # 每个点是多边形的角点, 用 (x, y) 表示
                        poly = np.array(seg).reshape((int(len(seg)/2), 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # mask
                    t = self.imgs[ann['image_id']]
                    if type(ann['segmentation']['counts']) == list:
                        rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
                    else:
                        rle = [ann['segmentation']]
                    m = maskUtils.decode(rle)
                    img = np.ones( (m.shape[0], m.shape[1], 3) )
                    if ann['iscrowd'] == 1:
                        color_mask = np.array([2.0,166.0,101.0])/255
                    if ann['iscrowd'] == 0:
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    for i in range(3):
                        img[:,:,i] = color_mask[i]
                    ax.imshow(np.dstack( (img, m*0.5) ))
             # 人体关节点 keypoints 标注信息
            if 'keypoints' in ann and type(ann['keypoints']) == list:
                # turn skeleton into zero-based index
                sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
                kp = np.array(ann['keypoints'])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk]>0):
                        plt.plot(x[sk],y[sk], linewidth=3, color=c)
                plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
                plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
        ax.add_collection(p)
    elif datasetType == 'captions':
        for ann in anns:
            print(ann['caption'])

COCO 中的 segmentation 标注信息是多边形, 我们只需要将多边形绘制出来就可以了.

一般在这个函数调用之前, 已经调用过图像显示函数: plt.imshow(). 因为将标注信息显示在图像上才更为直观. 在函数内部会使用 plt.gca() 函数获取当前 figure 的轴, 也就是之前显示的图片所在的 figure, 然后就可以将后续的标注信息显示在这个图上了.

你可能感兴趣的:(深度学习)