Python Tips(二):如何优雅地画混淆矩阵

在处理分类问题时,常常需要画混淆矩阵来分析数据的分类情况,这里安利一个画混淆矩阵的方法:

1.首先导入要用到的包:

import itertools

import matplotlib.pyplot as pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.metrics import confusion_matrix

2.定义混淆矩阵函数,进行相关参数设置:

def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap='Blues',  # colour theme of the heat map; 'Blues' looks clean
                          normalize=True,
                          save_path='./Confusion_Matrix.png'):
    """Draw a confusion matrix as an annotated heat map, save it and show it.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix of raw counts with rows indexed by true label and
        columns by predicted label (the layout produced by
        ``sklearn.metrics.confusion_matrix``, which this file uses).
    target_names : sequence of str or None
        Tick labels for the classes; ``None`` keeps matplotlib's default ticks.
    title : str
        Figure title.
    cmap : str, matplotlib Colormap or None
        Colour map for the heat map; ``None`` falls back to ``'Blues'``.
    normalize : bool
        If True, every row is divided by its total, so the cells (both the
        colours and the printed numbers) show the fraction of each true
        class. Rows with no samples are shown as 0 instead of NaN.
    save_path : str, path-like or None
        Where the PNG is written; ``None`` skips saving.

    The accuracy / misclassification rate printed under the x-axis are always
    computed from the raw counts, before any normalization.
    """
    cm = np.asarray(cm)
    total = cm.sum()
    accuracy = np.trace(cm) / float(total) if total else float('nan')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    # Normalize BEFORE drawing so the cell colours, the printed values and
    # the text-colour threshold below all refer to the same numbers.
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class with no true samples would divide by zero; leave its row at 0.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.figure(figsize=(9, 7))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Cells above this value get white text, so the number stays readable on
    # the high end of the colour map (the dark end for the default 'Blues').
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cell_fmt.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', size=15)
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass), size=15)
    # Lay out after the axis labels exist so they are taken into account.
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, format='png', bbox_inches='tight')
    plt.show()

3.封装一个函数:用模型对验证集做预测,再画出混淆矩阵:

def plot_confuse(model, x_val, y_val, labels):
    """Predict on a validation set and plot the resulting confusion matrix.

    Parameters
    ----------
    model :
        Classifier exposing ``predict(x, batch_size=...)`` that returns
        per-class scores (e.g. a Keras model) -- assumed, verify with callers.
    x_val :
        Validation inputs, passed straight to ``model.predict``.
    y_val : numpy.ndarray
        One-hot encoded ground truth; reduced to class indices with an
        argmax over the last axis.
    labels : sequence of str
        Class names used as tick labels.
    """
    scores = np.asarray(model.predict(x_val, batch_size=1))
    # ``Sequential.predict_classes`` has been removed from Keras, so apply its
    # rule here: argmax for multi-class outputs, 0.5 threshold for a single
    # (sigmoid) output column.
    if scores.shape[-1] > 1:
        predictions = scores.argmax(axis=-1)
    else:
        predictions = (scores > 0.5).astype('int32').ravel()
    truelabel = y_val.argmax(axis=-1)   # one-hot -> integer class labels
    conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
    # plot_confusion_matrix opens its own figure; calling plt.figure() here
    # as well would leave an empty extra figure behind.
    plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')

4.在测试集上预测并画出混淆矩阵:

# Run loaded_model on X_test, then reduce both the predicted scores and the
# (assumed one-hot) Y_test to plain Python lists of integer class indices.
predicted_label = np.argmax(loaded_model.predict(X_test), axis=-1).tolist()
Y_test = np.argmax(Y_test, axis=-1).tolist()

conf_mat = confusion_matrix(y_true=Y_test, y_pred=predicted_label)

# Six classes, labelled '1' .. '6' on both axes; raw counts are shown.
plot_confusion_matrix(
    conf_mat,
    normalize=False,
    target_names=[str(k) for k in range(1, 7)],
    title='Confusion Matrix',
)

效果如下:
[图 1:Python Tips(二):如何优雅地画混淆矩阵 —— 混淆矩阵效果图]

你可能感兴趣的:(python)