# 视觉分类任务中，混淆矩阵（Confusion Matrix）的绘制

# 混淆矩阵（Confusion Matrix）的绘制

import cv2
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import datasets, transforms

from models.drp import DRP
from models.trformer_dual import TRFormer as trformer_dual

class Config:
    """Hard-coded settings for this confusion-matrix evaluation script."""

    image_resize = 256              # size passed to transforms.Resize in the 'test' pipeline
    image_crop = 224                # crop size used by both the 'train' and 'test' pipelines
    batch_size = 32                 # NOTE(review): earlier note suggested 64 for MIT-Indoor, 128 for others
    backbone = "resnet50"           # vgg19 resnet18 resnet50 resnet101 densenet161
    dataset_name = "dtd_t-SNE"      # other options: "FMD" "dtd-r1.0.1" "4D_Light" "MIT-Indoor"
    # Root of the split; the script reads ImageFolder dirs data_dir/'train' and data_dir/'test'.
    data_dir = Path(f"data/{dataset_name}/splits/split_1")
    # Plain string: the former f-string had no placeholders.
    # NOTE(review): not read by the code shown here; plots are saved elsewhere.
    output_dir = Path("outputs/confusion_matrix")
    # torch checkpoint file; the script reads its "state_dict" entry.
    checkpoint = "outputs/drp_181.pth.tar"



# Preprocessing pipelines keyed by split name. Both finish with ToTensor and
# per-channel normalization (the usual ImageNet mean/std values).
DataTransforms = dict(
    # Training split: random resized crop + horizontal flip as augmentation.
    train=transforms.Compose([
        transforms.RandomResizedCrop(Config.image_crop),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # Test split: deterministic resize followed by a center crop.
    test=transforms.Compose([
        transforms.Resize(Config.image_resize),
        transforms.CenterCrop(Config.image_crop),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)


class DrawConfusionMatrix:
    """Accumulate (prediction, label) pairs and plot the resulting confusion matrix.

    Rows of the matrix index the ground-truth class and columns index the
    predicted class. This matches the axes drawn by ``drawMatrix``
    (y axis = "Truth label", x axis = "Predict label"), and it makes row
    normalization give per-true-class fractions.
    """

    def __init__(self, labels_name, normalize=True):
        """
        Args:
            labels_name: class names in label-index order; their count fixes the matrix size.
            normalize: whether drawMatrix shows row-normalized fractions (True)
                or raw counts (False).
        """
        self.normalize = normalize
        self.labels_name = labels_name
        self.num_classes = len(labels_name)
        # Raw counts. They are never normalized in place, so update() can keep
        # accumulating and getMatrix() can be called any number of times.
        self.matrix = np.zeros((self.num_classes, self.num_classes), dtype='float32')

    def update(self, predicts, labels):
        """Add one batch of predictions to the running counts.

        Args:
            predicts: 1-D predicted class indices, e.g. array([0,5,1,6,3,...], dtype=int64).
            labels: 1-D ground-truth class indices, e.g. array([0,5,0,6,2,...], dtype=int64).

        Raises:
            ValueError: if predicts and labels have different lengths.
        """
        predicts = np.asarray(predicts, dtype=np.int64).ravel()
        labels = np.asarray(labels, dtype=np.int64).ravel()
        if predicts.shape != labels.shape:
            raise ValueError(
                f"predicts and labels must have the same length, "
                f"got {predicts.size} and {labels.size}")
        # Row = truth, column = prediction. np.add.at is unbuffered, so repeated
        # (label, predict) pairs within one batch are all counted.
        np.add.at(self.matrix, (labels, predicts), 1)

    def getMatrix(self, normalize=True):
        """Return the confusion matrix without modifying the stored counts.

        Args:
            normalize: if True, divide each row by its sum (the number of samples
                of that true class) and round to 2 decimals. Rows with no samples
                become all zeros. If False, return the raw counts.

        Returns:
            A new float32 ndarray of shape (num_classes, num_classes).
        """
        if not normalize:
            return self.matrix.copy()
        row_sums = self.matrix.sum(axis=1, keepdims=True)
        # `where` skips empty rows, which avoids 0/0 -> NaN and its RuntimeWarning;
        # those entries keep the zeros from `out`.
        normalized = np.divide(self.matrix, row_sums,
                               out=np.zeros_like(self.matrix),
                               where=row_sums != 0)
        return np.around(normalized, 2)  # keep 2 decimal places

    def drawMatrix(self, save_path='./summer_ConfusionMatrix.png'):
        """Plot the confusion matrix, save it to ``save_path`` and show it.

        Args:
            save_path: output image path (defaults to the previously hard-coded file).
        """
        matrix = self.getMatrix(self.normalize)

        plt.imshow(matrix, cmap=plt.cm.summer)  # colored cells only, no values
        plt.title("Normalized confusion matrix" if self.normalize else "Confusion matrix")
        plt.xlabel("Predict label")
        plt.ylabel("Truth label")
        plt.yticks(range(self.num_classes), self.labels_name)               # y-axis tick labels
        plt.xticks(range(self.num_classes), self.labels_name, rotation=90)  # x-axis tick labels

        # Fractions get 2 decimals; raw counts are shown as integers.
        value_fmt = '{:.2f}' if self.normalize else '{:.0f}'
        for y in range(self.num_classes):
            for x in range(self.num_classes):
                plt.text(x, y, value_fmt.format(matrix[y, x]),
                         verticalalignment='center', horizontalalignment='center')

        plt.colorbar()
        plt.tight_layout()  # run after colorbar so the layout accounts for it
        # bbox_inches='tight' keeps the (possibly long) tick labels from being clipped
        plt.savefig(save_path, bbox_inches='tight')
        plt.show()



def inference(model, dataloaders, class_names, device=None):
    """Evaluate ``model`` on the 'test' split, print accuracy/timing, and plot the confusion matrix.

    Args:
        model: classifier whose output is reduced with ``torch.max(outputs, 1)``;
            presumably per-class scores of shape (batch, num_classes), so confirm this.
        dataloaders: dict whose 'test' entry yields (inputs, labels) batches.
        class_names: class names in label-index order, used for the plot axes.
        device: device to move batches to. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters.

    Returns:
        The matrix returned by ``DrawConfusionMatrix.getMatrix()`` after drawing.
    """
    drawconfusionmatrix = DrawConfusionMatrix(labels_name=class_names)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    since = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    t1 = time.time()
    for phase in ['test']:  # add 'train' here to evaluate the training split too
        model.eval()  # evaluation mode (e.g. dropout / batch-norm behavior)

        running_corrects = 0
        # Inference only, so skip building the autograd graph (saves memory and time).
        with torch.no_grad():
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels).item()

                # Record this batch's predictions and labels for the confusion matrix.
                drawconfusionmatrix.update(preds.cpu().numpy(), labels.cpu().numpy())

        num_samples = len(dataloaders[phase].dataset)
        accuracy = running_corrects / num_samples if num_samples else 0.0
        print('{} : Acc = {:.4f}'.format(phase, accuracy))

    t2 = time.time()
    print('-' * 35)
    print(f'Start Time : {since}')
    print(f'End Time : {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    print(f'Inference Time : {(t2-t1)/60:.4f} minute  ({(t2-t1):.4f}s)')
    print('=' * 39, '\n')

    drawconfusionmatrix.drawMatrix()  # plot from all accumulated predictions/labels
    return drawconfusionmatrix.getMatrix()


if __name__ == "__main__":
    # The guard matters because DataLoader(num_workers=4) starts worker processes.
    # With the "spawn" start method (the default on Windows and macOS) each worker
    # re-imports this module, and without the guard the whole script would run
    # again in every worker.

    # Create the device first so the checkpoint can be mapped onto it.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # prepare dataset
    image_datasets = {x: datasets.ImageFolder(Config.data_dir / x,
                                              transform=DataTransforms[x])
                      for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
    # Only training benefits from shuffling; evaluation order does not change the results.
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=Config.batch_size,
                                                  shuffle=(x == 'train'),
                                                  num_workers=4)
                   for x in ['train', 'test']}
    class_names = image_datasets['train'].classes

    # prepare net and load checkpoint
    net = DRP(Config.backbone, len(class_names))
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(Config.checkpoint, map_location=device)
    net.load_state_dict(checkpoint["state_dict"])

    model_ft = net.to(device)
    inference(model_ft, dataloaders, class_names)

# 你可能感兴趣的：深度学习，分类，矩阵，Python