如果对于 COCO 给出的标注信息只是做一些常规场景的使用, 参考 https://github.com/dengdan/coco/blob/master/PythonAPI/pycocoDemo.ipynb 脚本已经完全够用了.
脚本的具体使用方法可以参考我的博客: (一) COCO Python API - 使用篇.
但是对于那些想要进一步挖掘 COCO 数据集的小伙伴来说, 这是远远不够的, 这时就需要直接去源码中找答案了.
另外, YOLO 官方给出的模型是基于 COCO 数据集训练得到的, 但是 YOLO 官网给出的训练方法中, 对于将 COCO 数据集标注格式转换为 YOLO 所需的标注格式并没有给出脚本, 而是直接给出现成的训练数据文件: train.txt 和 标注框数据文件: labels/xxx.txt. 如果用户要在 COCO 数据集的基础上定制自己的数据集, 需要研究源码了.
其实关于 COCO Python API, 核心的文件是在: https://github.com/dengdan/coco/tree/master/PythonAPI/pycocotools/ 目录下, 分别有 coco.py, cocoeval.py 和 mask.py 这三个文件.
以下摘自脚本中对于整个脚本文件的一段描述:
# Microsoft COCO Toolbox. version 2.0
# Data, paper, and tutorials available at: http://mscoco.org/
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
这个文件实现了访问 COCO 数据集的接口.可以进行 COCO 标注信息的加载, 解析和可视化操作.
当然, 这个脚本只是提供了常规使用 COCO 数据集的方法, 特殊的使用方法用户可以自行定义.
coco.py 中定义了以下 API:
理解 API 功能的关键是理解这里边包含的几种 ID 信息:图片 ID、类别 ID 和标注信息 ID。
在给定的 API 中, “ann”=annotation, “cat”=category, and “img”=image.接下来就上述提到的 API 一一分析.
COCO 类的构造函数负责加载 json 文件, 并将其中的标注信息建立关联关系. 关联关系存在于以下两个变量中:
这样做的好处就是: 对于给定的图片 id, 就可以快速找到其所有的标注信息; 对于给定的类别 id, 就可以快速找到属于该类的所有图片.
class COCO:
def __init__(self, annotation_file=None):
"""
构造函数, 用来读取关联标注信息.
:param annotation_file (str): 标注文件路径名
:return:
"""
# 加载数据集
# 定义数据成员
self.dataset, self.anns, self.cats, self.imgs = dict(),dict(),dict(),dict()
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
# 加载标注信息
if not annotation_file == None:
print('loading annotations into memory...')
tic = time.time()
dataset = json.load(open(annotation_file, 'r'))
# dataset 的类型为 dict
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
print('Done (t={:0.2f}s)'.format(time.time()- tic))
self.dataset = dataset
self.createIndex()
def createIndex(self):
"""
创建索引, 属于构造函数的一部分.
"""
# create index
print('creating index...')
anns, cats, imgs = {}, {}, {}
# imgToAnns: 将图片 id 和标注信息 关联; 给定图片 id, 就可以找到其所有的标注信息
# catToImgs: 将类别 id 和图片 id 关联; 给定类别 id, 就可以找到属于这类的所有图片
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann
if 'images' in self.dataset:
for img in self.dataset['images']:
imgs[img['id']] = img
if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat
if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']].append(ann['image_id'])
print('index created!!!')
# 给几个类成员赋值
self.anns = anns
self.imgs = imgs
self.cats = cats
self.imgToAnns = imgToAnns
self.catToImgs = catToImgs
理解数据关联关系的重点是对 imgToAnns 变量类型的理解. 其中, imgToAnns 的类型为: defaultdict(
. dict 中的键为 id, 值为一个 list, 包含所有的属于这张图的标注信息. 其中 list 的成员类型为字典, 其元素是 “annotations” 域内的一条(或多条)标注信息.
例如:
{ 289343: [
{'id': 1768, 'area': 702.1057499999998,
'segmentation': [[510.66, ... 510.45, 423.01]],
'bbox': [473.07, 395.93, 38.65, 28.67],
'category_id': 18,
'iscrowd': 0,
'image_id': 289343
},
...
]
}
catToImgs 变量类似于 imgToAnns.
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
"""
获取满足给定条件的标注信息 id. 如果未指定条件, 则返回整个数据集上的标注信息 id
:param imgIds (int 数组) : 通过图片 ids 指定
catIds (int 数组) : 通过类别 ids 指定
areaRng (float 数组) : 通过面积大小的范围, 2 个元素. (e.g. [0 inf]) 指定
iscrowd (boolean) : get anns for given crowd label (False or True)
:return: ids (int 数组) : 满足条件的标注信息的 ids
"""
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
catIds = catIds if _isArrayLike(catIds) else [catIds]
if len(imgIds) == len(catIds) == len(areaRng) == 0:
anns = self.dataset['annotations']
else:
if not len(imgIds) == 0:
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
anns = list(itertools.chain.from_iterable(lists))
else:
anns = self.dataset['annotations']
anns = anns if len(catIds) == 0
else [ann for ann in anns if ann['category_id'] in catIds]
anns = anns if len(areaRng) == 0
else [ann for ann in anns
if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
if not iscrowd == None:
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
else:
ids = [ann['id'] for ann in anns]
return ids
def showAnns(self, anns):
"""
显示给定的标注信息. 一般在这个函数调用之前, 已经调用过图像显示函数: plt.imshow()
:param anns (object 数组): 要显示的标注信息
:return: None
"""
if len(anns) == 0:
return 0
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
datasetType = 'instances'
elif 'caption' in anns[0]:
datasetType = 'captions'
else:
raise Exception('datasetType not supported')
# 目标检测的标注数据
if datasetType == 'instances':
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in anns:
# 为每个标注 mask 生成一个随机颜色
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
# 边界 mask 标注信息
if 'segmentation' in ann:
if type(ann['segmentation']) == list:
# polygon
for seg in ann['segmentation']:
# 每个点是多边形的角点, 用 (x, y) 表示
poly = np.array(seg).reshape((int(len(seg)/2), 2))
polygons.append(Polygon(poly))
color.append(c)
else:
# mask
t = self.imgs[ann['image_id']]
if type(ann['segmentation']['counts']) == list:
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
else:
rle = [ann['segmentation']]
m = maskUtils.decode(rle)
img = np.ones( (m.shape[0], m.shape[1], 3) )
if ann['iscrowd'] == 1:
color_mask = np.array([2.0,166.0,101.0])/255
if ann['iscrowd'] == 0:
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack( (img, m*0.5) ))
# 人体关节点 keypoints 标注信息
if 'keypoints' in ann and type(ann['keypoints']) == list:
# turn skeleton into zero-based index
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
kp = np.array(ann['keypoints'])
x = kp[0::3]
y = kp[1::3]
v = kp[2::3]
for sk in sks:
if np.all(v[sk]>0):
plt.plot(x[sk],y[sk], linewidth=3, color=c)
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
ax.add_collection(p)
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
ax.add_collection(p)
elif datasetType == 'captions':
for ann in anns:
print(ann['caption'])
COCO 中的 segmentation 标注信息是多边形, 我们只需要将多边形绘制出来就可以了.
一般在这个函数调用之前, 已经调用过图像显示函数: plt.imshow(). 因为将标注信息显示在图像上才更为直观. 在函数内部会使用 plt.gca() 函数获取当前 figure 的轴, 也就是之前显示的图片所在的 figure, 然后就可以将后续的标注信息显示在这个图上了.