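A confusion matrix is a common evaluation tool for classification tasks. Each diagonal entry counts the samples whose predicted label equals the true label, while each off-diagonal entry counts the samples the classifier labeled incorrectly. Higher diagonal values are better, since they indicate many correct predictions. [1]
Especially when the classes are imbalanced, a confusion matrix shows which classes are being misclassified far more directly than accuracy does.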
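To make the point about imbalance concrete, here is a minimal sketch with made-up labels (95 samples of class 0, 5 of class 1) and a hypothetical classifier that always predicts the majority class. The data and the helper name `simple_confusion_matrix` are assumptions for illustration only.

```python
import numpy as np


def simple_confusion_matrix(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs; rows are true labels, columns are predictions."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same index pair repeats.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


# Made-up imbalanced data: 95 samples of class 0 and 5 samples of class 1.
y_true = np.array([0] * 95 + [1] * 5)
# A hypothetical classifier that always predicts the majority class.
y_pred = np.zeros_like(y_true)

accuracy = float((y_true == y_pred).mean())
cm = simple_confusion_matrix(y_true, y_pred, num_classes=2)

print(accuracy)  # 0.95
print(cm)        # second row is [5 0]: class 1 is never predicted correctly
```

Accuracy looks high at 0.95, yet the second row of the matrix makes it obvious that every class 1 sample was misclassified.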
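For quick data experiments, it is enough to visualize the confusion matrix with scikit-learn (from sklearn.metrics import plot_confusion_matrix; note that this helper was removed in scikit-learn 1.2, where ConfusionMatrixDisplay takes its place) or with seaborn. When training a deep learning model, however, visualizing the confusion matrix in tensorboard makes it easier to record and compare results.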
The code below is adapted from Facebook's SlowFast project [2]:
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix
Compute the confusion matrix from the predictions preds output by the PyTorch model and the ground-truth labels.
def get_confusion_matrix(preds, labels, num_classes, normalize="true"):
    """
    Calculate confusion matrix on the provided preds and labels.
    Args:
        preds (tensor or lists of tensors): predictions. Each tensor is
            in the shape of (n_batch, num_classes). Tensor(s) must be on CPU.
        labels (tensor or lists of tensors): corresponding labels. Each tensor is
            in the shape of either (n_batch,) or (n_batch, num_classes).
            Tensor(s) must be on CPU.
        num_classes (int): number of classes.
        normalize (Optional[str]) : {"true", "pred", "all"}, default="true"
            Normalizes confusion matrix over the true (rows), predicted (columns)
            conditions or all the population. If None, confusion matrix
            will not be normalized. Currently not forwarded to scikit-learn
            (see the call below), so raw counts are returned.
    Returns:
        cmtx (ndarray): confusion matrix of size (num_classes x num_classes)
    """
    if isinstance(preds, list):
        preds = torch.cat(preds, dim=0)
    if isinstance(labels, list):
        labels = torch.cat(labels, dim=0)
    # If labels are one-hot encoded, get their indices.
    if labels.ndim == preds.ndim:
        labels = torch.argmax(labels, dim=-1)
    # Get the predicted class indices for examples.
    preds = torch.flatten(torch.argmax(preds, dim=-1))
    labels = torch.flatten(labels)
    # normalize=normalize is left out of this call because some scikit-learn
    # versions (before 0.22) do not support that argument.
    cmtx = confusion_matrix(labels, preds, labels=list(range(num_classes)))
    return cmtx
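Because the `normalize` argument is not passed to `confusion_matrix` above, the returned matrix holds raw counts. If normalized values are wanted on a scikit-learn version that lacks the argument, the normalization can be done by hand. The helper name `normalize_confusion_matrix` is hypothetical; this is a sketch that mirrors the "true" / "pred" / "all" semantics described in the docstring, with made-up counts.

```python
import numpy as np


def normalize_confusion_matrix(cmtx, mode="true"):
    """Normalize raw counts over rows ("true"), columns ("pred") or all entries ("all")."""
    cmtx = np.asarray(cmtx, dtype=np.float64)
    if mode is None:
        return cmtx
    if mode == "true":
        denom = cmtx.sum(axis=1, keepdims=True)
    elif mode == "pred":
        denom = cmtx.sum(axis=0, keepdims=True)
    elif mode == "all":
        denom = cmtx.sum()
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    # Leave rows/columns without any samples at 0 instead of dividing by zero.
    return np.divide(cmtx, denom, out=np.zeros_like(cmtx), where=denom != 0)


counts = [[8, 2], [1, 9]]  # made-up raw counts
by_true = normalize_confusion_matrix(counts, "true")
by_pred = normalize_confusion_matrix(counts, "pred")
by_all = normalize_confusion_matrix(counts, "all")
print(by_true)
```

With "true" normalization each row sums to 1, so the diagonal can be read as per-class recall.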
Given the confusion matrix cmtx returned by get_confusion_matrix, the number of classes, and the class names, draw the confusion matrix.
def plot_confusion_matrix(cmtx, num_classes, class_names=None, figsize=None):
    """
    A function to create a colored and labeled confusion matrix matplotlib figure
    given true labels and preds.
    Args:
        cmtx (ndarray): confusion matrix.
        num_classes (int): total number of classes.
        class_names (Optional[list of strs]): a list of class names.
        figsize (Optional[float, float]): the figure size of the confusion matrix.
            If None, default to [6.4, 4.8].
    Returns:
        img (figure): matplotlib figure.
    """
    if class_names is None or not isinstance(class_names, list):
        class_names = [str(i) for i in range(num_classes)]
    figure = plt.figure(figsize=figsize)
    plt.imshow(cmtx, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    # Use white text if squares are dark; otherwise black.
    threshold = cmtx.max() / 2.0
    for i, j in itertools.product(range(cmtx.shape[0]), range(cmtx.shape[1])):
        color = "white" if cmtx[i, j] > threshold else "black"
        plt.text(
            j,
            i,
            format(cmtx[i, j], ".2f") if cmtx[i, j] != 0 else ".",
            horizontalalignment="center",
            color=color,
        )
    plt.tight_layout()
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    return figure
Display the figure returned by plot_confusion_matrix in tensorboard.
from torch.utils.tensorboard import SummaryWriter


def add_confusion_matrix(
    writer,
    cmtx,
    num_classes,
    global_step=None,
    subset_ids=None,
    class_names=None,
    tag="Confusion Matrix",
    figsize=None,
):
    """
    Calculate and plot confusion matrix to a SummaryWriter.
    Args:
        writer (SummaryWriter): the SummaryWriter to write the matrix to.
        cmtx (ndarray): confusion matrix.
        num_classes (int): total number of classes.
        global_step (Optional[int]): current step.
        subset_ids (list of ints): a list of label indices to keep.
        class_names (list of strs, optional): a list of all class names.
        tag (str or list of strs): name(s) of the confusion matrix image.
        figsize (Optional[float, float]): the figure size of the confusion matrix.
            If None, default to [6.4, 4.8].
    """
    # An explicitly empty subset_ids list means there is nothing to plot.
    if subset_ids is None or len(subset_ids) != 0:
        # If class names are not provided, use class indices as class names.
        if class_names is None:
            class_names = [str(i) for i in range(num_classes)]
        # If subset is not provided, take every class.
        if subset_ids is None:
            subset_ids = list(range(num_classes))
        sub_cmtx = cmtx[subset_ids, :][:, subset_ids]
        sub_names = [class_names[j] for j in subset_ids]
        figure = plot_confusion_matrix(
            sub_cmtx,
            num_classes=len(subset_ids),
            class_names=sub_names,
            figsize=figsize,
        )
        # Add the confusion matrix image to writer.
        writer.add_figure(tag=tag, figure=figure, global_step=global_step)
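The `subset_ids` argument keeps only the selected classes by indexing rows first and then columns. The small sketch below, with a made-up 4-class matrix and hypothetical class names, shows what that selection produces, and that `np.ix_` is an equivalent one-step form.

```python
import numpy as np

# Made-up 4x4 confusion matrix (rows: true labels, columns: predictions).
cmtx = np.array([
    [10, 1, 0, 2],
    [0, 8, 3, 0],
    [1, 0, 9, 1],
    [2, 2, 0, 7],
])
class_names = ["a", "b", "c", "d"]  # hypothetical class names
subset_ids = [0, 2]

# Same two-step indexing as in add_confusion_matrix.
sub_cmtx = cmtx[subset_ids, :][:, subset_ids]
sub_names = [class_names[j] for j in subset_ids]

# np.ix_ builds the row/column index grid in one step.
same = cmtx[np.ix_(subset_ids, subset_ids)]

print(sub_names)
print(sub_cmtx)
```

Predictions that fall into the dropped classes disappear from the sub-matrix, so its rows no longer sum to the per-class sample counts.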
In the training loop, collect the predictions and labels of every batch, then compute the confusion matrix and write it to tensorboard once the loop has finished.
model.train()
# Predictions and ground-truth labels, collected to draw the confusion matrix
preds = []
labels = []
for i, (inputs, targets) in enumerate(data_loader):
    targets = targets.to(device, non_blocking=True)  # shape: (n_batch,)
    outputs = model(inputs)  # shape: (n_batch, n_classes)
    loss = criterion(outputs, targets)
    # Detach from the autograd graph and move the tensors from GPU to CPU
    preds.append(outputs.detach().cpu())
    labels.append(targets.cpu())
    acc, recall = calculate_precision_and_recall(outputs, targets, pos_label=0)
    losses.update(float(loss.item()), inputs.size(0))
    accuracies.update(float(acc), inputs.size(0))
    recalls.update(float(recall), inputs.size(0))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Visualize the confusion matrix
preds = torch.cat(preds, dim=0)
labels = torch.cat(labels, dim=0)
cmtx = get_confusion_matrix(preds, labels, len(class_names))
add_confusion_matrix(
    tb_writer,
    cmtx,
    num_classes=len(class_names),
    class_names=class_names,
    tag="Train Confusion Matrix",
    figsize=[10, 8],
)
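Keeping every batch of outputs in a list can use a lot of memory over a long epoch. One alternative, sketched here with NumPy and made-up batches, is to accumulate the confusion matrix batch by batch. The class name `RunningConfusionMatrix` is hypothetical and not part of SlowFast.

```python
import numpy as np


class RunningConfusionMatrix:
    """Accumulate a confusion matrix incrementally (rows: true, columns: predicted)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.cmtx = np.zeros((num_classes, num_classes), dtype=np.int64)

    def update(self, scores, labels):
        """scores: (n_batch, num_classes) class scores; labels: (n_batch,) class indices."""
        preds = np.asarray(scores).argmax(axis=-1)
        labels = np.asarray(labels)
        # Encode each (true, pred) pair as one flat index, then count occurrences.
        flat = labels * self.num_classes + preds
        counts = np.bincount(flat, minlength=self.num_classes ** 2)
        self.cmtx += counts.reshape(self.num_classes, self.num_classes)


running = RunningConfusionMatrix(num_classes=3)
# Two made-up batches of scores with their ground-truth labels.
running.update([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [0, 2])
running.update([[0.1, 0.2, 3.0]], [2])
print(running.cmtx)
```

With PyTorch tensors, `outputs.detach().cpu().numpy()` and `targets.cpu().numpy()` could be passed to `update` inside the loop, and `running.cmtx` handed to add_confusion_matrix at the end of the epoch.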
[1] https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
[2] https://github.com/facebookresearch/SlowFast/tree/master/slowfast/visualization