混淆矩阵之精度、召回等评估指标的代码实现

import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable

class Confusion_Matrix(object):
    """
    Accumulate a confusion matrix over a whole dataset and report per-class metrics.

    Convention: ``matrix[pred, true]`` -- rows index the predicted class and
    columns index the true class.
    """
    def __init__(self, num_classes: int, labels=None):
        """
        Args:
            num_classes: number of classes; the matrix is ``num_classes x num_classes``.
            labels: optional display names for the classes, used in the printed
                table and on the plot axes. Defaults to ``0 .. num_classes - 1``.

        Raises:
            ValueError: if ``labels`` is given and its length is not ``num_classes``.
        """
        self.num_classes = num_classes
        self.matrix = np.zeros([num_classes, num_classes])
        if labels is None:
            labels = range(num_classes)
        self.labels = list(labels)
        if len(self.labels) != num_classes:
            raise ValueError(
                f"expected {num_classes} labels, got {len(self.labels)}")
        self.field_names = ["Precision", "Recall", "Specificity", "F1_score", "acc"]
        # One row per class, one column per entry of self.field_names;
        # filled in by Matrix_summary.
        self.infometrics = np.zeros([num_classes, len(self.field_names)])

    @staticmethod
    def _safe_div(num, den):
        """Return ``num / den`` as a float, or 0.0 when the denominator is 0."""
        return float(num / den) if den else 0.

    @staticmethod
    def _to_class_indices(values):
        """Turn (N, C) scores/one-hot rows into argmax indices; pass (N,) indices through."""
        arr = np.asarray(values)
        if arr.ndim == 2:
            return np.argmax(arr, axis=1)
        if arr.ndim == 1:
            return arr.astype(int)
        raise ValueError(
            f"expected a 1-D index array or a 2-D (N, C) array, got shape {arr.shape}")

    def Matrix_update(self, preds, labels):
        """
        Add one batch of predictions to the confusion matrix.

        Args:
            preds: (N, C) scores / one-hot rows (argmax is taken), or (N,)
                predicted class indices.
            labels: (N, C) one-hot rows (argmax is taken), or (N,) true class
                indices.

        Raises:
            ValueError: if an input is neither 1-D nor 2-D, or the two inputs
                describe a different number of samples.
            IndexError: if any class index falls outside ``[0, num_classes)``.
        """
        pred_idx = self._to_class_indices(preds)
        true_idx = self._to_class_indices(labels)
        if pred_idx.shape != true_idx.shape:
            raise ValueError(
                f"preds and labels describe different batch sizes: "
                f"{pred_idx.shape[0]} vs {true_idx.shape[0]}")
        for idx in (pred_idx, true_idx):
            # Reject negatives explicitly: numpy would otherwise wrap them around.
            if idx.size and (idx.min() < 0 or idx.max() >= self.num_classes):
                raise IndexError(
                    f"class index out of range [0, {self.num_classes})")
        # np.add.at accumulates correctly even when a (pred, true) pair repeats.
        np.add.at(self.matrix, (pred_idx, true_idx), 1)

    def Matrix_summary(self):
        """
        Print a per-class metrics table and return the macro averages.

        For class i (rows = predicted, columns = true):
            TP = matrix[i, i]
            FP = row i sum - TP     (predicted i, truly another class)
            FN = column i sum - TP  (truly i, predicted as another class)
            TN = total - TP - FP - FN
        Any ratio whose denominator is 0 is reported as 0. The table shows
        values rounded to 3 decimals. ``self.infometrics`` is overwritten with
        the unrounded values.

        Returns:
            np.ndarray of shape (len(self.field_names),): the column means of
            ``self.infometrics``, in the order Precision, Recall, Specificity,
            F1_score, acc.
        """
        total = np.sum(self.matrix)
        # Guard the empty matrix: 0/0 would otherwise yield nan plus a RuntimeWarning.
        accuracy = self._safe_div(np.trace(self.matrix), total)

        table = PrettyTable()
        # First column: class label; remaining columns: the metric names.
        table.field_names = ["num_classes"] + list(self.field_names)

        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = total - TP - FP - FN
            precision = self._safe_div(TP, TP + FP)
            recall = self._safe_div(TP, TP + FN)
            specificity = self._safe_div(TN, TN + FP)
            # Use the unrounded precision/recall so display rounding does not compound.
            f1_score = self._safe_div(2 * precision * recall, precision + recall)
            metrics = [precision, recall, specificity, f1_score, accuracy]
            self.infometrics[i, :] = metrics
            table.add_row([self.labels[i]] + [round(m, 3) for m in metrics])
        print(table)
        return np.average(self.infometrics, axis=0)

    def Matrix_plot(self, save_path='./混淆矩阵.jpg'):
        """
        Draw the confusion matrix as a heat map, annotate every cell with its
        count, optionally save the figure, then show it.

        The x axis is the true class (matrix columns) and the y axis is the
        predicted class (matrix rows).

        Args:
            save_path: file to save the figure to. The image format follows the
                extension (e.g. .jpg, .png). Pass None to skip saving.
        """
        # Configure a CJK-capable font before any artists are created, so the
        # Chinese labels and title are set up with it.
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False

        matrix = self.matrix
        plt.imshow(matrix, cmap=plt.cm.Reds)
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        plt.yticks(range(self.num_classes), self.labels)
        plt.colorbar()
        plt.xlabel('真实类别')  # "true class"
        plt.ylabel('预测类别')  # "predicted class"
        plt.title('混淆矩阵')  # "confusion matrix"

        # Annotate each cell with its count. Use white text on cells darker
        # than half of the maximum count so the number stays readable.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # imshow puts column index on x and row index on y.
                count = int(matrix[y, x])
                plt.text(x, y, count,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if count > thresh else "black")
        plt.tight_layout()
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

你可能感兴趣的:(矩阵,python,numpy)