混淆矩阵绘制

部分代码参考了其他博客。

import itertools

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cm, classes, normalize=False, cmap=plt.cm.Blues):
    """Print and plot a confusion matrix with a numeric annotation per cell.

    Normalization can be applied by setting ``normalize=True``.

    Input
    - cm : the computed confusion matrix (2-D, e.g. the output of
      ``sklearn.metrics.confusion_matrix``); rows are plotted against the
      "True label" axis and columns against the "Predicted label" axis.
    - classes : tick labels for the rows/columns; needs one entry per row
      of ``cm``.
    - normalize : True: divide each row by its total and show fractions
      (rounded to 2 decimals); False: show raw counts.
    - cmap : matplotlib colormap used for the image.

    Raises ValueError if ``len(classes)`` does not match the number of rows
    of ``cm``.
    """
    cm = np.asarray(cm)
    if len(classes) != cm.shape[0]:
        # Mismatched labels would silently mislabel the axes (or fail deep
        # inside matplotlib), so reject them up front.
        raise ValueError(
            "len(classes)={} does not match confusion matrix size {}".format(
                len(classes), cm.shape[0]))

    if normalize:
        # Divide each row by its sum. A row with no samples would divide by
        # zero (NaN plus a RuntimeWarning, and a NaN colour threshold below),
        # so such rows are left as 0 instead.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        title = 'Normalized confusion matrix'
        cb_title = 'Number of recordings (normalized)'
    else:
        title = 'Confusion matrix'
        cb_title = 'Number of recordings'
    print(title)
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    cb = plt.colorbar()
    cb.set_label(cb_title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    plt.axis("equal")
    ax = plt.gca()  # current axes
    # Move the left/right spines to the current x-limits, then paint all
    # four spines white so no visible frame is drawn around the image.
    left, right = plt.xlim()
    ax.spines['left'].set_position(('data', left))
    ax.spines['right'].set_position(('data', right))
    for edge in ('top', 'bottom', 'right', 'left'):
        ax.spines[edge].set_edgecolor("white")

    # Use white text on cells above half the maximum value for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        num = np.round(cm[i, j], 2) if normalize else int(cm[i, j])
        plt.text(j, i, num,
                 verticalalignment='center',
                 horizontalalignment="center",
                 color="white" if num > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out only after the axis labels exist, so they are not clipped.
    plt.tight_layout()
    plt.show()

测试（使用示例）

from sklearn.metrics import confusion_matrix

# Convert the PyTorch tensors (possibly on the GPU) to CPU NumPy arrays.
# .detach() is a no-op for tensors without grad, but without it .numpy()
# raises on tensors still attached to the autograd graph.
# NOTE(review): true_all, pred_all and classes are assumed to be defined
# earlier (e.g. accumulated over the evaluation loop); not shown here.
y_true = true_all.detach().cpu().numpy()
y_pred = pred_all.detach().cpu().numpy()
# Pin the label set so the matrix always has one row/column per entry of
# `classes`, even when some class never occurs in y_true or y_pred.
# Assumes labels are the integer indices 0..len(classes)-1 — confirm.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))

# Plot the confusion matrix of raw counts
plot_confusion_matrix(cm, classes, normalize=False)
# Plot the row-normalized confusion matrix
plot_confusion_matrix(cm, classes, normalize=True)

你可能感兴趣的:(深度学习-分类,深度学习,机器学习)