网上关于混淆矩阵的代码参差不齐,没找到可用的线程的代码,所以自己尝试写了下
首先它长这样:
怎么看?
Confusion Matrix最广泛的应用应该是分类,比如图中是7分类的真实标签和预测标签的效果。
首先图中表明了纵轴是truth label,横轴是predicted label,那么对于第一行第一个0.60的含义是:本来是angry标签的图,我的模型正确分类成angry的比例是60%,也即是angry这一类模型分类正确的精度只有60%。同时模型将angry分类成了happy的图占比0.04%,其他的以此类推。
注意:因为本身是angry,模型预测成7种类的数量占比。所以每一行的和为100%。
同时对于fear标签,模型分类成fear的占比41%,分类成sad的占比为20%,我们可以认为模型不能很好区分fear和sad两种类别。
先给出代码:
import numpy as np
import matplotlib.pyplot as plt
class DrawConfusionMatrix:
def __init__(self, labels_name):
"""
:param num_classes: 分类数目
"""
self.labels_name = labels_name
self.num_classes = len(labels_name)
self.matrix = np.zeros((self.num_classes, self.num_classes), dtype="float32")
def update(self, predicts, labels):
"""
:param predicts: 一维预测向量,eg:array([0,5,1,6,3,...],dtype=int64)
:param labels: 一维标签向量:eg:array([0,5,0,6,2,...],dtype=int64)
:return:
"""
for predict, label in zip(predicts, labels):
self.matrix[predict, label] += 1
def draw(self):
per_sum = self.matrix.sum(axis=1) # 计算每行的和,用于百分比计算
for i in range(self.num_classes):
self.matrix[i] = (self.matrix[i] / per_sum[i]) # 百分比
plt.imshow(self.matrix, cmap=plt.cm.Blues) # 仅画出颜色格子,没有值
plt.title("Normalized confusion matrix") # title
plt.xlabel("Predict label")
plt.ylabel("Truth label")
plt.yticks(range(self.num_classes), self.labels_name) # y轴标签
plt.xticks(range(self.num_classes), self.labels_name, rotation=45) # x轴标签
for x in range(self.num_classes):
for y in range(self.num_classes):
value = float(format('%.2f' % self.matrix[y, x])) # 数值处理
plt.text(x, y, value, verticalalignment='center', horizontalalignment='center') # 写值
plt.tight_layout() # 自动调整子图参数,使之填充整个图像区域
plt.colorbar() # 色条
plt.savefig('./ConfusionMatrix.png', bbox_inches='tight') # bbox_inches='tight'可确保标签信息显示全
plt.show()
混淆矩阵是将所有数据的label和predict整理而画的,但实际中往往是分成多个iter来推测batch_size个数据,所以需要update()
函数来讲每一次的label和predict值保存进去,模型推理完成后,再调用draw()
函数画出混淆矩阵并保存为图片
给出一个简单的实例:
labels_name=['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
drawconfusionmatrix = DrawConfusionMatrix(labels_name=labels_name) # 实例化
for index, (labels, imgs) in enumerate(test_loader):
labels_pd = model(imgs)
predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1) # array([0,5,1,6,3,...],dtype=int64)
labels_np = labels.numpy() # array([0,5,0,6,2,...],dtype=int64)
drawconfusionmatrix.update(predict_np, labels_np) # 将新批次的predict和label更新(保存)
drawconfusionmatrix.draw() # 根据所有predict和label,画出混淆矩阵