Python手动输入混淆矩阵,并计算混淆矩阵的准确率、精确率、召回率、特异度、F1-score

import os
import json
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable

class ConfusionMatrix(object):
    def __init__(self, num_classes: int, labels: list):
        # 手动输入混淆矩阵,以5×5的矩阵为例。
        self.matrix = np.array([[592, 0, 0, 0, 0],
                                [0, 592, 1, 0, 0],
                                [0, 2, 598, 0, 1],
                                [0, 1, 0, 599, 0],
                                [0, 0, 1, 1, 594]])
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("模型准确率为:", acc)

        # calculate precision, recall, specificity, F1-socre
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity","F1-score"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN

            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            F1_score = round(2 * Precision * Recall / (Recall + Precision),3) if Recall + Precision != 0 else 0.

            table.add_row([self.labels[i], Precision, Recall, Specificity, F1_score])
        print(table)

    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)  # 设置混淆矩阵的颜色

        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels)
        # 显示colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # 在图中标注数量/概率信息
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()


if __name__ == '__main__':
    # read class_indict
    # class_indices.json 文件中存放的是分类的类别, ./class_indices.json 表示当前目录下的class_indices.json文件
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
    json_file = open(json_label_path, 'r')
    class_indict = json.load(json_file)

    labels = [label for _, label in class_indict.items()]

    # num_classes为分类的类别
    confusion = ConfusionMatrix(num_classes=5, labels=labels)

    confusion.plot()
    confusion.summary()

其中json格式的文件如下:

{
    "0": "13",
    "1": "18",
    "2": "23",
    "3": "28",
    "4": "33"
}

可以按照以上格式(以5分类为例),先写在记事本上再更改后缀名

*注意最后一个后面没有 “,” 

没有扩展名的看下面这个图给它调出来↓

Python手动输入混淆矩阵,并计算混淆矩阵的准确率、精确率、召回率、特异度、F1-score_第1张图片

代码部分参考如下:

 参考文献:使用pytorch和tensorflow计算分类模型的混淆矩阵_哔哩哔哩_bilibili

你可能感兴趣的:(python,矩阵,numpy)