在 mmdetection 中使用 COCO API(pycocotools)评估模型时,打印每个类别的指标并格式化输出(训练过程中的验证同样适用):
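双击 Shift 全局搜索 summarize,找到 cocoeval.py 中 summarize() 内部的 _summarize() 函数,将其修改为以下实现(用按类别输出的代码替换原来的 print 语句)。注意:_summarize() 和 _summarizeDets() 都是嵌套在 summarize() 中的函数,粘贴时请保持原文件的缩进: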
def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
p = self.params
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>4d} ] = {:<8.3f}'
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
typeStr = '(AP)' if ap==1 else '(AR)'
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
if iouThr is None else '{:0.2f}'.format(iouThr)
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
if ap == 1:
# dimension of precision: [TxRxKxAxM]
s = self.eval['precision']
# IoU
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:,:,:,aind,mind]
else:
# dimension of recall: [TxKxAxM]
s = self.eval['recall']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:,:,aind,mind]
if len(s[s>-1])==0:
mean_s = -1
else:
mean_s = np.mean(s[s>-1])
# print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
# 添加的代码=============================================
category_dimension = 1 + int(ap)
if s.shape[category_dimension] > 1:
iStr += ", per category = ["
# iStr += ", per category = {}"
if ap == 1:
mean_axis = (0, 1)
else:
mean_axis = (0,)
per_category_mean_s = np.mean(s, axis=mean_axis).flatten()
print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s), end='')
for i in range(len(per_category_mean_s)):
if i == len(per_category_mean_s) - 1:
end_str = ''
else:
end_str = ', '
print(f'{per_category_mean_s[i]:<9.3f}', end=end_str)
print(']')
# with np.printoptions(precision=3, suppress=True, sign=" ", floatmode="fixed"):
# print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s, per_category_mean_s))
else:
print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s, ""))
# 添加的代码=============================================
return mean_s
继续向下找到 _summarizeDets() 函数,在函数开头添加打印类别名称表头的代码:
def _summarizeDets():
# 添加的代码=============================================
if len(self.cocoDt.cats) > 1:
cats = [self.cocoDt.cats[i]['name'] for i in range(len(self.cocoDt.cats))]
iStr_title = ' {:>78} , cats name = ['
print(iStr_title.format('Mean'), end='')
for i in range(len(cats)):
if i == len(cats) - 1:
end_str = ''
else:
end_str = ', '
print(f'{cats[i]:<9}', end=end_str)
print(']')
# 添加的代码=============================================
stats = np.zeros((12,))
stats[0] = _summarize(1)
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
return stats