# 1. 首先导入要用到的包 (import the required packages):
import numpy as np
import pandas as pd
import matplotlib.pyplot as pl
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 2. 定义混淆矩阵绘图函数，并设置相关参数 (define the confusion-matrix plotting function and its parameters):
def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap='Blues',  # colour theme of the matrix; 'Blues' gives a clean look
                          normalize=True,
                          save_path='./Confusion_Matrix.png'):
    """Draw a confusion matrix as an annotated heat map, save it and show it.

    Args:
        cm: Square confusion matrix (array-like). Row ``i`` is drawn on the
            'True label' axis and column ``j`` on the 'Predicted label' axis.
        target_names: Tick labels for the classes, or ``None`` to omit ticks.
        title: Figure title.
        cmap: Matplotlib colormap name or object; ``None`` falls back to 'Blues'.
        normalize: If true, each row is divided by its sum so cells (and the
            colours) show per-true-class fractions. Rows that sum to zero are
            shown as 0 instead of producing NaN.
        save_path: File the figure is written to as PNG; ``None`` skips saving.

    The accuracy / misclassification rate printed under the x axis are always
    computed from the raw (un-normalized) matrix.
    """
    cm = np.asarray(cm)
    total = np.sum(cm)
    # Avoid a divide-by-zero warning for an empty / all-zero matrix.
    accuracy = np.trace(cm) / float(total) if total else float('nan')
    misclass = 1 - accuracy
    if cmap is None:
        cmap = plt.get_cmap('Blues')

    # Normalize BEFORE drawing so the colours, colorbar and text all show the
    # same quantity (previously the image showed raw counts while the text
    # showed fractions).
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.figure(figsize=(9, 7))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Cells brighter than `thresh` get white text so it stays readable.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, cell_fmt.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', size=15)
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass), size=15)
    # Lay out after the labels exist so they are taken into account.
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, format='png', bbox_inches='tight')
    plt.show()
# 3. 显示混淆矩阵 (display the confusion matrix):
def plot_confuse(model, x_val, y_val, labels, batch_size=1):
    """Predict on a validation set and plot its confusion matrix (raw counts).

    Args:
        model: Classifier exposing ``predict``. Assumed to return one score per
            class along the last axis (e.g. a softmax Keras model) — confirm.
        x_val: Validation inputs passed to ``model.predict``.
        y_val: One-hot ground-truth labels, classes along the last axis.
        labels: Class names used as tick labels.
        batch_size: Batch size forwarded to ``model.predict``.
    """
    # `predict_classes` is gone from recent Keras releases; argmax over the
    # predicted scores is the multi-class equivalent (as used elsewhere in this script).
    predictions = np.argmax(model.predict(x_val, batch_size=batch_size), axis=-1)
    truelabel = np.argmax(y_val, axis=-1)  # one-hot -> class index
    # Pin the label set to every one-hot column so the matrix stays square and
    # aligned with `labels` even if a class is absent from both y_true and y_pred.
    conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions,
                                labels=np.arange(np.shape(y_val)[-1]))
    # plot_confusion_matrix opens its own figure; creating one here would leave
    # an empty extra figure behind.
    plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')
# 4. 执行函数 (run the functions):
# Predicted class index per test sample: argmax over the model's output scores.
predicted_label = np.argmax(loaded_model.predict(X_test), axis=-1)
# Ground truth: one-hot -> class index; both are then converted to plain lists.
Y_test = np.argmax(Y_test, axis=-1)
Y_test = Y_test.tolist()
predicted_label = predicted_label.tolist()
# Display names of the classes. Assumes six classes whose index k is shown as
# str(k + 1) — TODO confirm against the dataset.
class_names = ['1', '2', '3', '4', '5', '6']
# Pin the label set so the matrix is always len(class_names) square and stays
# aligned with the tick labels even if some class never occurs in the test set.
conf_mat = confusion_matrix(y_true=Y_test, y_pred=predicted_label,
                            labels=list(range(len(class_names))))
plot_confusion_matrix(conf_mat, normalize=False, target_names=class_names,
                      title='Confusion Matrix')