基于MMdetection框架的目标检测研究-6.绘制混淆矩阵

文章背景:

当我们训练完模型后,我们需要用训练后的模型对正负样本图片进行目标检测测试,这时候我们需要算模型在新的数据集上的检测效果(精度、过杀率、漏检率,准确度等),这时候使用测试后的结果绘制成混淆矩阵,可以很方便的帮助我们呈现和理解模型的泛华能力。

核心代码:

# -*- coding=utf-8 -*-
'''
功能说明:根据已有的分类数据,绘制相应的混淆矩阵,便于统计
过杀率和漏检率
'''
import numpy as np
import matplotlib.pyplot as plt
# 修改类别列表中的数据和矩阵中数据可以绘制多类混淆矩阵
classes = ['OK ','NG']
confusion_matrix = np.array([(20,5),(5,55)],dtype=np.float64)
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Oranges)  #按照像素显示出矩阵
plt.title('confusion_matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
thresh = confusion_matrix.max() / 2.
#iters = [[i,j] for i in range(len(classes)) for j in range((classes))]
#ij配对,遍历矩阵迭代器
iters = np.reshape([[[i,j] for j in range(len(classes))] for i in range(len(classes))],(confusion_matrix.size,2))
for i, j in iters:
    plt.text(j, i, format(confusion_matrix[i, j]))   #显示对应的数字
plt.ylabel('Real label')
plt.xlabel('Prediction')
plt.tight_layout()
#plt.show()
# 保存每次生成的图像
f = plt.gcf()  #获取当前图像
f.savefig(r'./{}.png'.format('result'))# 一定要放到plt.show()前面,否则保存图像为空白
plt.show()#plt.show() 后实际上已经创建了一个新的空白的图片
#f.clear()  #释放内存,迭代保存的时候,plt.plot()会出现多根线在一张图叠加,可以加这句话
print('混淆矩阵图像绘制结束并保存在当前路径下。')

结果显示如下,并在代码路径下保存生成结果:

基于MMdetection框架的目标检测研究-6.绘制混淆矩阵_第1张图片

混淆矩阵图分析: 

该混淆矩阵结果图表示的是,OK实际测试样本有25个,预测为OK的样本有20个,预测为NG的样本有5个。NG实际测试样本有60个,预测为NG的有55个,预测为OK的样本有5个。

你可能感兴趣的:(MMdetection,python,目标检测,混淆矩阵,python,MMdetection)