pycocotools.coco source code analysis

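The main purpose of coco.py is to load the annotation file. The API functions it defines make it easy to get both the ids and the actual contents of categories, annotations and images.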

It defines the following API functions:

#  COCO       - COCO api class that loads COCO annotation file and prepare data structures. 
#  decodeMask - Decode binary mask M encoded via run-length encoding.                       
#  encodeMask - Encode binary mask M using run-length encoding.                             
#  getAnnIds  - Get ann ids that satisfy given filter conditions.                           
#  getCatIds  - Get cat ids that satisfy given filter conditions.                           
#  getImgIds  - Get img ids that satisfy given filter conditions.                           
#  loadAnns   - Load anns with the specified ids.                                           
#  loadCats   - Load cats with the specified ids.                                           
#  loadImgs   - Load imgs with the specified ids.                                           
#  annToMask  - Convert segmentation in an annotation to binary mask.                       
#  showAnns   - Display the specified annotations.                                          
#  loadRes    - Load algorithm results and create API for accessing them.                   
#  download   - Download COCO images from mscoco.org server.                                
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.                     
                         
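  • COCO: the COCO class loads a COCO annotation file and builds the data structures. For example, COCO.dataset or COCO.imgs give access to the contents of the original annotation .json file (mostly as dicts)
  • getAnnIds / getCatIds / getImgIds: return the ids of the anns / cats / imgs that satisfy the given filter conditions
  • loadAnns / loadCats / loadImgs: load the anns / cats / imgs with the given ids
  • annToMask: convert the segmentation of an annotation into a binary mask
  • showAnns: draw the specified annotations on the corresponding image
  • loadRes: load algorithm results and create an API for accessing them
  • download: download COCO images from the mscoco.org server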

1. Definition of _isArrayLike()

def _isArrayLike(obj):                                          
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 

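If the object has both the '__iter__' and '__len__' attributes, i.e. it is iterable and has a length (e.g. a list), the function returns True, otherwise False. Note that in Python 3 a str also passes this check.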

2. COCO class initialization (reading the ann file and building the index)

import json
import time
from collections import defaultdict

class COCO:
    def __init__(self, annotation_file=None):                                                           
        """                                                                                             
        Constructor of Microsoft COCO helper class for reading and visualizing annotations.             
        :param annotation_file (str): location of annotation file                                       
        :param image_folder (str): location to the folder that hosts images.                            
        :return:                                                                                        
        """                                                                                             
        # load dataset                                                                                  
        self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()                        
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)                           
        if not annotation_file == None:                                                                 
            print('loading annotations into memory...')                                                 
            tic = time.time()                                                                           
            with open(annotation_file, 'r') as f:                                                       
                dataset = json.load(f)                                                                  
            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) 
            print('Done (t={:0.2f}s)'.format(time.time()- tic))                                         
            self.dataset = dataset                                                                      
            self.createIndex()                                                                          
                                                                                                        
    def createIndex(self):                                                                              
        # create index                                                                                  
        print('creating index...')                                                                      
        anns, cats, imgs = {}, {}, {}                                                                   
        imgToAnns,catToImgs = defaultdict(list),defaultdict(list)                                       
        if 'annotations' in self.dataset:                                                               
            for ann in self.dataset['annotations']:                                                     
                imgToAnns[ann['image_id']].append(ann)                                                  
                anns[ann['id']] = ann                                                                   
                                                                                                        
        if 'images' in self.dataset:                                                                    
            for img in self.dataset['images']:                                                          
                imgs[img['id']] = img                                                                   
                                                                                                        
        if 'categories' in self.dataset:                                                                
            for cat in self.dataset['categories']:                                                      
                cats[cat['id']] = cat                                                                   
                                                                                                        
        if 'annotations' in self.dataset and 'categories' in self.dataset:                              
            for ann in self.dataset['annotations']:                                                     
                catToImgs[ann['category_id']].append(ann['image_id'])                                   
                                                                                                        
        print('index created!')                                                                         
                                                                                                        
        # create class members                                                                          
        self.anns = anns                                                                                
        self.imgToAnns = imgToAnns                                                                      
        self.catToImgs = catToImgs                                                                      
        self.imgs = imgs                                                                                
        self.cats = cats                                                                                
                                                                                                        
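  • self.dataset = json.load(f) is a dict holding the entire top-level dictionary read from the annotation file;
  • anns = {ann_id: ann}, imgToAnns = {img_id: [ann, ... all annotations of that image]}
  • Likewise imgs = {img_id: img} and cats = {cat_id: cat}. Together these split the whole dataset (the big annotation dict) into lookup tables for anns, cats and imgs, each keyed by id
  • catToImgs = {cat_id: [img_id, ... ids of every image in which the category appears]}, which prepares for later returning all images that contain a given cat_id. Because one img_id is appended per annotation, an image with several instances of the same category appears several times in the list; getImgIds removes the duplicates with set()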
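As a minimal sketch, the indexing can be reproduced on a hypothetical, heavily trimmed dataset without pycocotools. build_index is a name made up here; it follows the same steps as createIndex but returns the structures instead of storing them on self:

```python
from collections import defaultdict

# Hypothetical, heavily trimmed annotation dict in COCO layout
dataset = {
    'images': [{'id': 1, 'file_name': 'a.jpg'}, {'id': 2, 'file_name': 'b.jpg'}],
    'categories': [{'id': 18, 'name': 'dog', 'supercategory': 'animal'},
                   {'id': 58, 'name': 'hot dog', 'supercategory': 'food'}],
    'annotations': [
        {'id': 101, 'image_id': 1, 'category_id': 18, 'area': 500.0, 'iscrowd': 0},
        {'id': 102, 'image_id': 1, 'category_id': 18, 'area': 80.0, 'iscrowd': 0},
        {'id': 103, 'image_id': 2, 'category_id': 58, 'area': 300.0, 'iscrowd': 1},
    ],
}

def build_index(dataset):
    """Same bookkeeping as COCO.createIndex, returned instead of stored on self."""
    anns, cats, imgs = {}, {}, {}
    imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
    for ann in dataset.get('annotations', []):
        imgToAnns[ann['image_id']].append(ann)
        anns[ann['id']] = ann
    for img in dataset.get('images', []):
        imgs[img['id']] = img
    for cat in dataset.get('categories', []):
        cats[cat['id']] = cat
    if 'annotations' in dataset and 'categories' in dataset:
        for ann in dataset['annotations']:
            catToImgs[ann['category_id']].append(ann['image_id'])
    return anns, cats, imgs, imgToAnns, catToImgs

anns, cats, imgs, imgToAnns, catToImgs = build_index(dataset)
print(sorted(anns))                     # [101, 102, 103]
print([a['id'] for a in imgToAnns[1]])  # [101, 102]
print(dict(catToImgs))                  # {18: [1, 1], 58: [2]}  (image 1 listed twice)
```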

3. getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None)

def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):                                                         
    """                                                                                                                      
    Get ann ids that satisfy given filter conditions. default skips that filter                                              
    :param imgIds  (int array)     : get anns for given imgs                                                                 
           catIds  (int array)     : get anns for given cats                                                                 
           areaRng (float array)   : get anns for given area range (e.g. [0 inf])                                            
           iscrowd (boolean)       : get anns for given crowd label (False or True)                                          
    :return: ids (int array)       : integer array of ann ids                                                                
    """                                                                                                                      
    imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]                                                                    
    catIds = catIds if _isArrayLike(catIds) else [catIds]                                                                    
                                                                                                                             
    if len(imgIds) == len(catIds) == len(areaRng) == 0:                                                                      
        anns = self.dataset['annotations']                                                                                   
    else:                                                                                                                    
        if not len(imgIds) == 0:                                                                                             
            lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]                                   
            anns = list(itertools.chain.from_iterable(lists))                                                                
        else:                                                                                                                
            anns = self.dataset['annotations']                                                                               
        anns = anns if len(catIds)  == 0 else [ann for ann in anns if ann['category_id'] in catIds]                          
        anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 
    if not iscrowd == None:                                                                                                  
        ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]                                                       
    else:                                                                                                                    
        ids = [ann['id'] for ann in anns]                                                                                    
    return ids                                                                                                               
                                                                                                                             

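Given the filter conditions imgIds/catIds/areaRng/iscrowd, getAnnIds selects the matching annotations and returns their ids as ids=[id, ...]. This works because every annotation stores its image_id, category_id, area and iscrowd. If imgIds is given, the candidates are first narrowed through the imgToAnns index (unknown image ids are silently skipped); the category, area and iscrowd filters are then applied in turn, so all conditions must hold at once. The area test is strict on both ends (area > areaRng[0] and area < areaRng[1]). If no filter is given at all, the ids of every annotation in the dataset are returned.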
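The filtering logic can be sketched in plain Python on toy data. get_ann_ids and the sample annotations are hypothetical; the function follows the same steps as getAnnIds, but uses () instead of [] as defaults and reads module-level data instead of self:

```python
import itertools

def _isArrayLike(obj):
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')

# Hypothetical annotations plus the imgToAnns index createIndex would build
annotations = [
    {'id': 101, 'image_id': 1, 'category_id': 18, 'area': 500.0, 'iscrowd': 0},
    {'id': 102, 'image_id': 1, 'category_id': 18, 'area': 80.0, 'iscrowd': 0},
    {'id': 103, 'image_id': 2, 'category_id': 58, 'area': 300.0, 'iscrowd': 1},
]
imgToAnns = {}
for a in annotations:
    imgToAnns.setdefault(a['image_id'], []).append(a)

def get_ann_ids(imgIds=(), catIds=(), areaRng=(), iscrowd=None):
    """Mirrors COCO.getAnnIds on the module-level toy data above."""
    imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
    catIds = catIds if _isArrayLike(catIds) else [catIds]
    if len(imgIds) > 0:  # narrow via the index first; unknown ids are skipped
        anns = list(itertools.chain.from_iterable(
            imgToAnns[i] for i in imgIds if i in imgToAnns))
    else:
        anns = annotations
    if len(catIds) > 0:
        anns = [a for a in anns if a['category_id'] in catIds]
    if len(areaRng) > 0:  # open interval: both ends excluded
        anns = [a for a in anns if areaRng[0] < a['area'] < areaRng[1]]
    if iscrowd is not None:
        anns = [a for a in anns if a['iscrowd'] == iscrowd]
    return [a['id'] for a in anns]

print(get_ann_ids())                                  # [101, 102, 103]
print(get_ann_ids(imgIds=1))                          # [101, 102]
print(get_ann_ids(catIds=[18], areaRng=[100, 1e10]))  # [101]
print(get_ann_ids(iscrowd=0))                         # [101, 102]
```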

4. getCatIds(self, catNms=[], supNms=[], catIds=[])

def getCatIds(self, catNms=[], supNms=[], catIds=[]):                                                 
    """                                                                                               
    filtering parameters. default skips that filter.                                                  
    :param catNms (str array)  : get cats for given cat names                                         
    :param supNms (str array)  : get cats for given supercategory names                               
    :param catIds (int array)  : get cats for given cat ids                                           
    :return: ids (int array)   : integer array of cat ids                                             
    """                                                                                               
    catNms = catNms if _isArrayLike(catNms) else [catNms]                                             
    supNms = supNms if _isArrayLike(supNms) else [supNms]                                             
    catIds = catIds if _isArrayLike(catIds) else [catIds]                                             
                                                                                                      
    if len(catNms) == len(supNms) == len(catIds) == 0:                                                
        cats = self.dataset['categories']                                                             
    else:                                                                                             
        cats = self.dataset['categories']                                                             
        cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name']          in catNms]  
        cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]  
        cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id']            in catIds]  
    ids = [cat['id'] for cat in cats]                                                                 
    return ids                                                                                        

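Like getAnnIds(), getCatIds() filters by catNms, supNms and catIds and returns the cat_ids whose name or supercategory name matches. Passing catIds is even more direct: the ids are essentially copied back, keeping only those that exist in the dataset. The filters are combined with AND, and the result follows the order of dataset['categories'] rather than the order of the arguments. With no arguments, every cat_id in dataset["categories"] is returned.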
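A sketch of the same logic on a hypothetical category list (ids and names chosen to resemble COCO's, not read from a real file). It also shows a side effect of _isArrayLike: a plain string is not wrapped in a list, so `cat['name'] in catNms` becomes a substring test. Passing the name inside a list avoids this:

```python
def _isArrayLike(obj):
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')

# Hypothetical category list
categories = [
    {'id': 18, 'name': 'dog', 'supercategory': 'animal'},
    {'id': 58, 'name': 'hot dog', 'supercategory': 'food'},
    {'id': 1, 'name': 'person', 'supercategory': 'person'},
]

def get_cat_ids(catNms=(), supNms=(), catIds=()):
    """Mirrors COCO.getCatIds on the toy list above."""
    catNms = catNms if _isArrayLike(catNms) else [catNms]
    supNms = supNms if _isArrayLike(supNms) else [supNms]
    catIds = catIds if _isArrayLike(catIds) else [catIds]
    cats = categories
    if len(catNms) > 0:
        cats = [c for c in cats if c['name'] in catNms]
    if len(supNms) > 0:
        cats = [c for c in cats if c['supercategory'] in supNms]
    if len(catIds) > 0:
        cats = [c for c in cats if c['id'] in catIds]
    return [c['id'] for c in cats]

print(get_cat_ids(catNms=['hot dog']))         # [58]
print(get_cat_ids(supNms=['animal', 'food']))  # [18, 58]
print(get_cat_ids(catNms='hot dog'))           # [18, 58]: 'dog' in 'hot dog' is True
print(get_cat_ids(catIds=[1, 18]))             # [18, 1]: list order, not argument order
```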

5. getImgIds(self, imgIds=[], catIds=[])

def getImgIds(self, imgIds=[], catIds=[]):                        
    '''                                                           
    Get img ids that satisfy given filter conditions.             
    :param imgIds (int array) : get imgs for given ids            
    :param catIds (int array) : get imgs with all given cats      
    :return: ids (int array)  : integer array of img ids          
    '''                                                           
    imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]         
    catIds = catIds if _isArrayLike(catIds) else [catIds]         
                                                                  
    if len(imgIds) == len(catIds) == 0:                           
        ids = self.imgs.keys()                                    
    else:                                                         
        ids = set(imgIds)                                         
        for i, catId in enumerate(catIds):                        
            if i == 0 and len(ids) == 0:                          
                ids = set(self.catToImgs[catId])                  
            else:                                                 
                ids &= set(self.catToImgs[catId])                 
    return list(ids)                                              

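The filter conditions are imgIds and catIds. imgIds is direct: the given ids are essentially copied back (they are not checked against the dataset). For catIds, the function uses self.catToImgs, built during initialization, to find all img_ids in which the category appears. With several catIds the sets are intersected, so only images containing all of the given categories are returned; when imgIds is also given, the result is further restricted to those images. Since ids is a set, the order of the returned list should not be relied upon.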
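A sketch with a hypothetical catToImgs index shows the intersection semantics. get_img_ids is a made-up name; it reads module-level data instead of self and sorts its result only so the output is predictable:

```python
def _isArrayLike(obj):
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')

# Hypothetical index: image ids per category, duplicates allowed as in createIndex
imgs = {1: {}, 2: {}, 3: {}}
catToImgs = {18: [1, 1, 3], 58: [2, 3]}

def get_img_ids(imgIds=(), catIds=()):
    """Mirrors COCO.getImgIds: images must contain ALL given categories."""
    imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
    catIds = catIds if _isArrayLike(catIds) else [catIds]
    if len(imgIds) == len(catIds) == 0:
        return list(imgs.keys())
    ids = set(imgIds)
    for i, catId in enumerate(catIds):
        if i == 0 and len(ids) == 0:
            ids = set(catToImgs.get(catId, []))  # coco.py uses a defaultdict here
        else:
            ids &= set(catToImgs.get(catId, []))
    return sorted(ids)  # sorted only to make the output deterministic

print(get_img_ids())                          # [1, 2, 3]
print(get_img_ids(catIds=18))                 # [1, 3]
print(get_img_ids(catIds=[18, 58]))           # [3]
print(get_img_ids(imgIds=[1, 2], catIds=58))  # [2]
print(get_img_ids(imgIds=[999]))              # [999]: imgIds are not validated
```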

6. loadAnns/Cats/Imgs()

def loadAnns(self, ids=[]):                                            
    """                                                                
    Load anns with the specified ids.                                  
    :param ids (int array)       : integer ids specifying anns         
    :return: anns (object array) : loaded ann objects                  
    """                                                                
    if _isArrayLike(ids):                                              
        return [self.anns[id] for id in ids]                           
    elif type(ids) == int:                                             
        return [self.anns[ids]]                                        
                                                                       
def loadCats(self, ids=[]):                                            
    """                                                                
    Load cats with the specified ids.                                  
    :param ids (int array)       : integer ids specifying cats         
    :return: cats (object array) : loaded cat objects                  
    """                                                                
    if _isArrayLike(ids):                                              
        return [self.cats[id] for id in ids]                           
    elif type(ids) == int:                                             
        return [self.cats[ids]]                                        
                                                                       
def loadImgs(self, ids=[]):                                            
    """                                                                
    Load imgs with the specified ids.
    :param ids (int array)       : integer ids specifying img          
    :return: imgs (object array) : loaded img objects                  
    """                                                                
    if _isArrayLike(ids):                                              
        return [self.imgs[id] for id in ids]                           
    elif type(ids) == int:                                             
        return [self.imgs[ids]]                                        

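Given ids=[id, ...], these functions return the dicts with those ids from the dataset's "annotations", "categories" or "images" respectively, again as a list [{annotation_1}, ...]. The lookup goes through the self.anns / self.cats / self.imgs indexes built in createIndex, so an unknown id raises KeyError. A single int id is also accepted and wrapped in a one-element list.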
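The branching shared by the three load functions can be exercised on a hypothetical toy index (load is a made-up name standing in for any of them). One consequence worth knowing: a scalar id whose type is not exactly int, such as a float or a numpy.int64 taken from an array, matches neither branch, so the function silently returns None; converting it with int() first avoids this.

```python
def _isArrayLike(obj):
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')

# Hypothetical id -> dict index, playing the role of self.anns / self.cats / self.imgs
anns = {101: {'id': 101, 'category_id': 18}, 103: {'id': 103, 'category_id': 58}}

def load(index, ids=()):
    """Same branching as loadAnns / loadCats / loadImgs."""
    if _isArrayLike(ids):
        return [index[i] for i in ids]
    elif type(ids) == int:
        return [index[ids]]
    # any other scalar (e.g. 101.0) falls through and the function returns None

print(load(anns, [101, 103]))  # two dicts, in the order requested
print(load(anns, 101))         # [{'id': 101, 'category_id': 18}]
print(load(anns, 101.0))       # None
try:
    load(anns, [999])
except KeyError as e:
    print('unknown id:', e)    # unknown id: 999
```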

7. showAnns(self, anns, draw_bbox=False)

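import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils

def showAnns(self, anns, draw_bbox=False):
    """
    Display the specified annotations.
    :param anns (array of object): annotations to display
    :param draw_bbox (bool): also draw each annotation's bbox
    :return: None
    """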
    if len(anns) == 0:
        return 0
    if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
        datasetType = 'instances'
    elif 'caption' in anns[0]:
        datasetType = 'captions'
    else:
        raise Exception('datasetType not supported')
    if datasetType == 'instances':
        ax = plt.gca()  # current axes, so the annotations are drawn over the image already shown
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            if 'segmentation' in ann:
                if type(ann['segmentation']) == list:
                    # polygon
                    for seg in ann['segmentation']:
                        poly = np.array(seg).reshape((int(len(seg) / 2), 2))  # flat [x1, y1, x2, y2, ...] -> (N, 2) array of (x, y)
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # mask
                    t = self.imgs[ann['image_id']]
                    if type(ann['segmentation']['counts']) == list:
                        rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
                    else:
                        rle = [ann['segmentation']]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann['iscrowd'] == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    if ann['iscrowd'] == 0:
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    for i in range(3):
                        img[:, :, i] = color_mask[i]
                    ax.imshow(np.dstack((img, m * 0.5)))
            if 'keypoints' in ann and type(ann['keypoints']) == list:
                # turn skeleton into zero-based index
                sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton']) - 1
                kp = np.array(ann['keypoints'])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(x[v > 0], y[v > 0], 'o', markersize=8, markerfacecolor=c, markeredgecolor='k',
                         markeredgewidth=2)
                plt.plot(x[v > 1], y[v > 1], 'o', markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)

            if draw_bbox:
                [bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
                poly = [[bbox_x, bbox_y], [bbox_x, bbox_y + bbox_h], [bbox_x + bbox_w, bbox_y + bbox_h],
                        [bbox_x + bbox_w, bbox_y]]
                np_poly = np.array(poly).reshape((4, 2))
                polygons.append(Polygon(np_poly))
                color.append(c)

        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
        ax.add_collection(p)
    elif datasetType == 'captions':
        for ann in anns:
            print(ann['caption'])
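The geometry inside showAnns can be checked without matplotlib. The helpers below are hypothetical names that reproduce three steps of the code above with plain lists: the corner order used by the draw_bbox branch, the kp[0::3] / kp[1::3] / kp[2::3] split of a flat keypoint list, and the skeleton test (1-based pairs shifted to 0-based, a limb drawn only if both endpoints have v > 0). In the COCO keypoint format v=0 means not labeled, v=1 labeled but not visible, and v=2 labeled and visible; showAnns gives every v > 0 point a black edge and then redraws the v > 1 points with an edge in the annotation's color.

```python
def bbox_to_polygon(bbox):
    """COCO bbox [x, y, w, h] -> 4 corners, in the order used by showAnns."""
    x, y, w, h = bbox
    return [[x, y], [x, y + h], [x + w, y + h], [x + w, y]]

def split_keypoints(keypoints):
    """Flat [x1, y1, v1, x2, y2, v2, ...] -> separate x, y, v lists (like kp[0::3] etc.)."""
    return keypoints[0::3], keypoints[1::3], keypoints[2::3]

def drawable_limbs(skeleton, v):
    """Skeleton pairs are 1-based; keep limbs whose two endpoints both have v > 0."""
    limbs = []
    for a, b in skeleton:
        i, j = a - 1, b - 1  # to 0-based, like `np.array(skeleton) - 1`
        if v[i] > 0 and v[j] > 0:
            limbs.append((i, j))
    return limbs

print(bbox_to_polygon([10, 20, 30, 40]))  # [[10, 20], [10, 60], [40, 60], [40, 20]]

kp = [100, 50, 2, 110, 60, 1, 0, 0, 0]   # 3 hypothetical keypoints
x, y, v = split_keypoints(kp)
print(x, y, v)                              # [100, 110, 0] [50, 60, 0] [2, 1, 0]
print(drawable_limbs([[1, 2], [2, 3]], v))  # [(0, 1)]
```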

More will be added later as needed.
