import os
import json
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
class ConfusionMatrix(object):
    """Confusion matrix with per-class metrics and a matplotlib heatmap.

    Layout convention (fixed by ``update``): rows are *predicted* classes and
    columns are *true* classes, i.e. ``matrix[pred, true]`` is a sample count.
    """

    # Hand-entered 5x5 example confusion matrix, used when no matrix is given.
    # Kept as nested tuples so every instance builds its own fresh array.
    _EXAMPLE_MATRIX = (
        (592, 0, 0, 0, 0),
        (0, 592, 1, 0, 0),
        (0, 2, 598, 0, 1),
        (0, 1, 0, 599, 0),
        (0, 0, 1, 1, 594),
    )

    def __init__(self, num_classes: int, labels: list, matrix=None):
        """Create the confusion matrix.

        Args:
            num_classes: number of classes reported by ``summary``/``plot``.
            labels: display names, where ``labels[i]`` names class ``i``.
            matrix: optional square array-like of counts laid out as
                ``matrix[pred, true]``. When omitted, the hand-entered 5x5
                example matrix is used (the original behavior).

        Raises:
            ValueError: if ``matrix`` is given but is not
                ``num_classes`` x ``num_classes``.
        """
        if matrix is None:
            self.matrix = np.array(self._EXAMPLE_MATRIX)
        else:
            # np.array copies, so later update() calls never mutate the caller's data.
            self.matrix = np.array(matrix)
            if self.matrix.shape != (num_classes, num_classes):
                raise ValueError(
                    "matrix must have shape ({0}, {0}), got {1}".format(
                        num_classes, self.matrix.shape))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        """Add one count per (prediction, ground-truth) pair.

        Args:
            preds: iterable of predicted class indices (row index).
            labels: iterable of true class indices (column index).
        """
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def _accuracy(self):
        """Return trace / total over the first ``num_classes`` diagonal cells.

        Returns 0. for an all-zero matrix instead of dividing by zero.
        """
        total = np.sum(self.matrix)
        correct = self.matrix.diagonal()[:self.num_classes].sum()
        return correct / total if total != 0 else 0.

    def _per_class_metrics(self):
        """Return one ``[label, precision, recall, specificity, f1]`` row per class.

        All metrics are rounded to 3 decimals; a metric whose denominator is
        zero is reported as 0.
        """
        total = np.sum(self.matrix).item()
        rows = []
        for i in range(self.num_classes):
            tp = self.matrix[i, i].item()
            # Row i = predicted as i, so the off-diagonal part is false positives.
            fp = np.sum(self.matrix[i, :]).item() - tp
            # Column i = truly i, so the off-diagonal part is false negatives.
            fn = np.sum(self.matrix[:, i]).item() - tp
            tn = total - tp - fp - fn
            precision = tp / (tp + fp) if tp + fp != 0 else 0.
            recall = tp / (tp + fn) if tp + fn != 0 else 0.
            specificity = tn / (tn + fp) if tn + fp != 0 else 0.
            # F1 is computed from the *unrounded* precision/recall; rounding
            # them first shifts the third decimal of F1.
            if precision + recall != 0:
                f1 = 2 * precision * recall / (precision + recall)
            else:
                f1 = 0.
            rows.append([self.labels[i], round(precision, 3), round(recall, 3),
                         round(specificity, 3), round(f1, 3)])
        return rows

    def summary(self):
        """Print overall accuracy and a per-class metrics table."""
        acc = self._accuracy()
        print("模型准确率为:", acc)
        # Precision, recall, specificity and F1-score for every class.
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity", "F1-score"]
        for row in self._per_class_metrics():
            table.add_row(row)
        print(table)

    def plot(self):
        """Print the raw matrix, then draw it as an annotated heatmap and call plt.show()."""
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)  # colour map for the matrix cells
        # Columns are true classes -> x axis tick labels.
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # Rows are predicted classes -> y axis tick labels.
        plt.yticks(range(self.num_classes), self.labels)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')
        # Write each cell's count; light text on dark cells for readability.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # The text is placed at (x=column, y=row), so read matrix[y, x].
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()
if __name__ == '__main__':
    # Read the class-index -> class-name mapping.
    # class_indices.json stores the class names; './class_indices.json' is the
    # class_indices.json file in the current working directory.
    json_label_path = './class_indices.json'
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(json_label_path):
        raise FileNotFoundError("cannot find {} file".format(json_label_path))
    # `with` guarantees the handle is closed; utf-8 so non-ASCII class names
    # decode the same on every platform.
    with open(json_label_path, 'r', encoding='utf-8') as json_file:
        class_indict = json.load(json_file)
    # Order labels by numeric index ("0", "1", ...) so labels[i] always names
    # matrix row/column i, independent of key order in the file.
    labels = [label for _, label in
              sorted(class_indict.items(), key=lambda kv: int(kv[0]))]
    # num_classes is the number of classes; 5 matches the hand-entered 5x5 matrix.
    confusion = ConfusionMatrix(num_classes=5, labels=labels)
    confusion.plot()
    confusion.summary()
其中 class_indices.json 文件的格式如下:
{
"0": "13",
"1": "18",
"2": "23",
"3": "28",
"4": "33"
}
可以按照以上格式(以5分类为例)先在记事本中编写,保存后再将文件后缀名改为 .json(即 class_indices.json)。
*注意:最后一项后面没有逗号 “,”(JSON 格式不允许尾随逗号)。
如果文件没有显示扩展名,请在文件管理器中开启“显示文件扩展名”选项,将扩展名显示出来(原文此处附有示意图↓)。
代码部分参考了以下资料:
参考文献:使用pytorch和tensorflow计算分类模型的混淆矩阵_哔哩哔哩_bilibili