matplot 绘制混淆矩阵

导入包

import pandas as pd
import numpy as np
from sklearn import metrics
import itertools
import matplotlib.pyplot as plt

读取文件

df = pd.read_csv('xxx.csv')
wrong_labels = df['wrong_labels'].tolist()
right_labels = df['right_labels'].tolist()

# 构造混淆矩阵
my_confusion_matrix = metrics.confusion_matrix(np.array(wrong_labels), np.array(right_labels), labels=None, sample_weight=None)

# 所有的分类 label
labels = list(set(right_labels))

绘制混淆矩阵

def plot_Matrix(cm, classes, title=None,  cmap=plt.cm.Blues):
    plt.rc('font',family='sans-serif',size='4.5')   # 设置字体样式、大小
    plt.rcParams['font.sans-serif'] = ['Tahoma', 'DejaVu Sans', 'SimHei', 'Lucida Grande', 'Verdana']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    plt.figure(figsize=(200, 200))
    plt.rcParams['figure.dpi'] = 200 #分辨率

    
    # 按行进行归一化
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    # 占比1%以下的单元格,设为0,防止在最后的颜色中体现出来
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if int(cm[i, j]*100 + 0.5) == 0:
                cm[i, j]=0

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax) # 侧边的颜色条带
    
    plt.title('Confusion matrix')
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=list(range(len(classes))), yticklabels=list(range(len(classes))),
           title=title,
           ylabel='Actual',
           xlabel='Predicted')

    # 通过绘制格网,模拟每个单元格的边框
    ax.set_xticks(np.arange(cm.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(cm.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="gray", linestyle='-', linewidth=0.05)
    ax.tick_params(which="minor", bottom=False, left=False)

    # 将x轴上的lables旋转45度
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # 标注百分比信息
    fmt = 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if int(cm[i, j]*100 + 0.5) > 0:
                ax.text(j, i, format(int(cm[i, j]*100 + 0.5) , fmt) + '%',
                        ha="center", va="center",
                        color="white"  if cm[i, j] > thresh else "black")
    fig.tight_layout()
    plt.show()

plot_Matrix(my_confusion_matrix, labels, title='Normalized confusion matrix')

result

matplot 绘制混淆矩阵_第1张图片

你可能感兴趣的:(算法)