# Confusion matrix (混淆矩阵) plotting
import cv2
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, transforms
from models.drp import DRP
from models.trformer_dual import TRFormer as trformer_dual
class Config:
    """Static settings for evaluating a checkpoint and plotting its confusion matrix."""

    # Sizes handed to the Resize / CenterCrop / RandomResizedCrop transforms.
    image_resize = 256
    image_crop = 224

    # DataLoader batch size.
    batch_size = 32

    # Backbone name passed to the model constructor.
    backbone = "resnet50"

    # Dataset location: data/<dataset_name>/splits/split_1/{train,test}
    dataset_name = "dtd_t-SNE"
    data_dir = Path("data") / dataset_name / "splits" / "split_1"

    # Output folder for plots and the checkpoint file to evaluate.
    output_dir = Path("outputs") / "confusion_matrix"
    checkpoint = "outputs/drp_181.pth.tar"
# Preprocessing pipelines keyed by split name.
#   'train': random resized crop + horizontal flip (augmentation)
#   'test' : deterministic resize + center crop
# Both convert to a tensor and normalize with the standard ImageNet
# per-channel mean / std values.
DataTransforms = dict(
    train=transforms.Compose(
        [
            transforms.RandomResizedCrop(Config.image_crop),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    test=transforms.Compose(
        [
            transforms.Resize(Config.image_resize),
            transforms.CenterCrop(Config.image_crop),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
)
class DrawConfusionMatrix:
    """Accumulate a confusion matrix over batches and plot it with matplotlib.

    Counts are stored with rows indexed by the ground-truth class and columns
    by the predicted class. This matches the axis labels used by
    ``drawMatrix`` ("Truth label" on y, "Predict label" on x).
    """

    def __init__(self, labels_name, normalize=True):
        """
        Args:
            labels_name: sequence of class names. Its length fixes the matrix
                size, and the names are used as tick labels.
            normalize: if True, ``drawMatrix`` shows each row as fractions of
                that row's total (per-true-class rates) instead of raw counts.
        """
        self.normalize = normalize
        self.labels_name = labels_name
        self.num_classes = len(labels_name)
        # Raw counts. Never normalized in place, so the matrix can be queried
        # (and drawn) any number of times.
        self.matrix = np.zeros((self.num_classes, self.num_classes), dtype='float32')

    def update(self, predicts, labels):
        """Add one batch of predictions to the running counts.

        Args:
            predicts: 1-D integer class indices predicted by the model,
                e.g. array([0, 5, 1, 6, 3, ...], dtype=int64).
            labels: 1-D integer ground-truth class indices, same length as
                ``predicts``, e.g. array([0, 5, 0, 6, 2, ...], dtype=int64).

        Raises:
            ValueError: if ``predicts`` and ``labels`` differ in length.
        """
        predicts = np.asarray(predicts, dtype=np.int64).ravel()
        labels = np.asarray(labels, dtype=np.int64).ravel()
        if predicts.shape != labels.shape:
            raise ValueError(
                f"predicts and labels must have the same length, "
                f"got {predicts.size} and {labels.size}"
            )
        # np.add.at is unbuffered: a (row, col) pair repeated within one batch
        # is counted every time (plain fancy-index `+=` would count it once).
        np.add.at(self.matrix, (labels, predicts), 1)

    def getMatrix(self, normalize=True):
        """Return a copy of the confusion matrix.

        Args:
            normalize: if True, divide each row by its sum and round to two
                decimal places. Rows with no samples become all zeros instead
                of NaN. If False, return the raw counts.

        Returns:
            A new (num_classes, num_classes) float32 array. ``self.matrix`` is
            left untouched, so repeated calls give consistent results.
        """
        counts = self.matrix.copy()
        if not normalize:
            return counts
        row_sums = counts.sum(axis=1, keepdims=True)
        # where= skips empty rows, avoiding divide-by-zero warnings and NaNs.
        rates = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)
        return np.around(rates, 2)

    def drawMatrix(self, save_path='./summer_ConfusionMatrix.png'):
        """Plot the matrix in a new figure, save it to ``save_path`` and show it.

        Args:
            save_path: output image path. Defaults to
                './summer_ConfusionMatrix.png' in the working directory.
        """
        matrix = self.getMatrix(self.normalize)
        ticks = range(self.num_classes)

        plt.figure()
        plt.imshow(matrix, cmap=plt.cm.summer)
        plt.title("Normalized confusion matrix" if self.normalize else "Confusion matrix")
        plt.xlabel("Predict label")
        plt.ylabel("Truth label")
        plt.yticks(ticks, self.labels_name)
        plt.xticks(ticks, self.labels_name, rotation=90)

        # imshow draws row index on the y axis, so cell matrix[y, x] sits at (x, y).
        value_fmt = '{:.2f}' if self.normalize else '{:.0f}'
        for y in range(self.num_classes):
            for x in range(self.num_classes):
                plt.text(x, y, value_fmt.format(matrix[y, x]),
                         verticalalignment='center', horizontalalignment='center')

        # Add the colorbar before tight_layout so the layout accounts for it.
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight')
        plt.show()
def inference(model, dataloaders, class_names, phase='test', device=None):
    """Evaluate ``model`` on one split, print accuracy and timing, then plot the confusion matrix.

    Args:
        model: classifier whose output is reduced with ``torch.max(outputs, 1)``.
            Presumably logits of shape (batch, num_classes); TODO confirm.
        dataloaders: mapping from split name to a DataLoader yielding
            ``(inputs, labels)`` batches.
        class_names: class names used as confusion-matrix tick labels.
        phase: key of ``dataloaders`` to evaluate (default 'test').
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter.

    Returns:
        The normalized confusion matrix from ``DrawConfusionMatrix.getMatrix()``.
    """
    drawconfusionmatrix = DrawConfusionMatrix(labels_name=class_names)
    if device is None:
        device = next(model.parameters()).device
    loader = dataloaders[phase]

    since = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    t1 = time.time()

    model.eval()
    running_corrects = 0
    # Evaluation only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == labels).sum().item()
            drawconfusionmatrix.update(preds.cpu().numpy(), labels.cpu().numpy())

    # max(..., 1) avoids ZeroDivisionError on an empty dataset.
    accuracy = running_corrects / max(len(loader.dataset), 1)
    print('{} : Acc = {:.4f}'.format(phase, accuracy))

    t2 = time.time()
    print('-' * 35)
    print(f'Start Time : {since}')
    print(f'End Time : {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    print(f'Inference Time : {(t2-t1)/60:.4f} minute ({(t2-t1):.4f}s)')
    print('=' * 39, '\n')

    drawconfusionmatrix.drawMatrix()
    return drawconfusionmatrix.getMatrix()
if __name__ == "__main__":
    # Guard required: DataLoader(num_workers > 0) re-imports the main module in
    # worker processes under the "spawn" start method (Windows/macOS). Without
    # this guard, every worker would re-run the whole script.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_datasets = {x: datasets.ImageFolder(Config.data_dir / x,
                                              transform=DataTransforms[x])
                      for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
    # Only the training split is shuffled. Evaluation order does not affect
    # accuracy or the confusion matrix.
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=Config.batch_size,
                                                  shuffle=(x == 'train'),
                                                  num_workers=4)
                   for x in ['train', 'test']}
    class_names = image_datasets['train'].classes

    net = DRP(Config.backbone, len(class_names))
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load(Config.checkpoint, map_location=device)["state_dict"])
    model_ft = net.to(device)
    inference(model_ft, dataloaders, class_names)