混淆矩阵:用于多分类模型评估(pytorch)

混淆矩阵(confusion matrix)

  • 1. 混淆矩阵介绍
  • 2. 代码实现
    • 2.1 数据集
    • 2.2 代码:混淆矩阵类
    • 2.3 在验证集上计算相关指标
    • 2.4 结果

1. 混淆矩阵介绍

这里不多说,可参考

  • 混淆矩阵相关概念
  • 调用sklearn库计算混淆矩阵的指标

2. 代码实现

2.1 数据集

此数据集用于多分类任务(检测番茄叶片病虫害)。这里测试的数据集一共1250张图,1000张用于训练,250张用于验证,共分为5个类别。数据集结构如下:
混淆矩阵:用于多分类模型评估(pytorch)_第1张图片
数据集部分图片展示:
混淆矩阵:用于多分类模型评估(pytorch)_第2张图片

2.2 代码:混淆矩阵类

计算accuracy、kappa、precision、recall、specificity

class ConfusionMatrix(object):


    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))#初始化混淆矩阵,元素都为0
        self.num_classes = num_classes#类别数量,本例数据集类别为5
        self.labels = labels#类别标签

    def update(self, preds, labels):
        for p, t in zip(preds, labels):#pred为预测结果,labels为真实标签
            self.matrix[p, t] += 1#根据预测结果和真实标签的值统计数量,在混淆矩阵相应位置+1

    def summary(self):#计算指标函数
        # calculate accuracy
        sum_TP = 0
        n = np.sum(self.matrix)
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]#混淆矩阵对角线的元素之和,也就是分类正确的数量
        acc = sum_TP / n#总体准确率
        print("the model accuracy is ", acc)
		
		# kappa
        sum_po = 0
        sum_pe = 0
        for i in range(len(self.matrix[0])):
            sum_po += self.matrix[i][i]
            row = np.sum(self.matrix[i, :])
            col = np.sum(self.matrix[:, i])
            sum_pe += row * col
        po = sum_po / n
        pe = sum_pe / (n * n)
        # print(po, pe)
        kappa = round((po - pe) / (1 - pe), 3)
        #print("the model kappa is ", kappa)
        
        # precision, recall, specificity
        table = PrettyTable()#创建一个表格
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i in range(self.num_classes):#精确度、召回率、特异度的计算
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN

            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.#每一类准确度
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.

            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)
        return str(acc)

    def plot(self):#绘制混淆矩阵
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels)
        # 显示colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix (acc='+self.summary()+')')

        # 在图中标注数量/概率信息
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()

2.3 在验证集上计算相关指标

在每个epoch计算一次指标,输出混淆矩阵并可视化

训练过程验证部分代码如下:

		class_indict = config.tomato_DICT
		#tomato_DICT = {'0': 'Bacterial_spot', '1': 'Early_blight', '2': 'healthy', '3': 'Late_blight', '4': 'Leaf_Mold'}
        label = [label for _, label in class_indict.items()]
        confusion = ConfusionMatrix(num_classes=config.NUM_CLASSES, labels=label)
        #实例化混淆矩阵,这里NUM_CLASSES = 5

        with torch.no_grad():
            model.eval()#验证
            for j, (inputs, labels) in enumerate(val_data):
                inputs = inputs.to(device)
                labels = labels.to(device)
                output = model(inputs)#分类网络的输出,分类器用的softmax,即使不使用softmax也不影响分类结果。
                loss = loss_function(output, labels)
                valid_loss += loss.item() * inputs.size(0)
                ret, predictions = torch.max(output.data, 1)#torch.max获取output最大值以及下标,predictions即为预测值(概率最大),这里是获取验证集每个batchsize的预测结果
                #confusion_matrix
                confusion.update(predictions.cpu().numpy(), labels.cpu().numpy())


            confusion.plot()
            confusion.summary()

2.4 结果

训练30个epoch,在第29个epoch取得最好的结果:
混淆矩阵:用于多分类模型评估(pytorch)_第3张图片

混淆矩阵:用于多分类模型评估(pytorch)_第4张图片

真实标签和预测标签在不同位置(x坐标和y坐标)都是可以的,看个人习惯,计算的时候注意就行了

你可能感兴趣的:(Deep_Learning,python,深度学习,机器学习)