本文汇总了多分类问题中混淆矩阵可视化的几种 Python 实现。
最简单的一种是直接在终端打印混淆矩阵结果, 如:
import sys
def confusion_matrix(gt_labels, pred_labels, num_labels):
    """Print a confusion matrix with a per-class accuracy column to stdout.

    The matrix is computed with ``sklearn.metrics.confusion_matrix`` over the
    integer class ids ``0 .. num_labels - 1``. Each printed row corresponds to
    one ground-truth class and each tab-separated column to one predicted
    class. The trailing ``| xx.xx %`` column is the share of that row's
    samples that lie on the diagonal (i.e. were predicted correctly).

    Args:
        gt_labels: Ground-truth class ids, passed to sklearn as ``y_true``.
        pred_labels: Predicted class ids, passed to sklearn as ``y_pred``.
        num_labels: Number of classes; fixes the size of the matrix.

    Returns:
        None. The table is written to ``sys.stdout``.
    """
    # Imported under an alias so the sklearn function is not confused with
    # this wrapper, which deliberately carries the same name.
    from sklearn.metrics import confusion_matrix as sk_confusion_matrix

    conf_matrix = sk_confusion_matrix(
        gt_labels, pred_labels, labels=list(range(num_labels))
    )

    sys.stdout.write('\n\nConfusion Matrix')
    sys.stdout.write('\t' * (num_labels - 2) + '| Accuracy')
    sys.stdout.write('\n' + '-' * 8 * (num_labels + 1))
    sys.stdout.write('\n')

    for i, row in enumerate(conf_matrix):
        for count in row:
            # Built-in int() replaces the removed ``np.int`` alias.
            sys.stdout.write(str(int(count)) + '\t')
        row_total = row.sum()
        # A class with no ground-truth samples has an undefined accuracy;
        # report NaN explicitly instead of dividing 0 by 0 (which emits a
        # NumPy RuntimeWarning before producing the same NaN).
        accuracy = row[i] * 100 / row_total if row_total else float('nan')
        sys.stdout.write('| %3.2f %%' % accuracy)
        sys.stdout.write('\n')

    sys.stdout.write('Number of test samples: %i \n\n' % conf_matrix.sum())
1. 示例1
from sklearn.metrics import confusion_matrix
# Example 1: draw the confusion matrix as a colour-mapped image.
# `y_test`, `pred` and `plt` (matplotlib.pyplot) are expected to be defined
# by the surrounding code.

# Class names, in the order used for the rows and columns of the matrix.
labels = ['business', 'health']

# `labels` must be passed by keyword: it is a keyword-only argument in
# current scikit-learn releases, so the positional form raises TypeError.
cm = confusion_matrix(y_test, pred, labels=labels)
print(cm)

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)

# Pin exactly one tick per class before assigning the tick labels. Relying
# on the automatic tick locator (and padding the label list with '') puts
# labels on the wrong cells or triggers a warning on current matplotlib.
tick_positions = list(range(len(labels)))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)

plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
2. 示例2
def plot_confusion_matrix(cm,
target_names,
title='Confusion matrix',
cmap=None,
normalize=True):
"""
given a sklearn confusion matrix (cm), make a nice plot
Arguments
---------
cm: confusion matrix from sklearn.metrics.confusion_matrix
target_names: given classification classes such as [0, 1, 2]
the class names, for example: ['high', 'medium', 'low']
title: the text to display at the top of the matrix
cmap: the gradient of the values displayed from matplotlib.pyplot.cm
see:
http://matplotlib.org/examples/color/colormaps_reference.html
plt.get_cmap('jet') or plt.cm.Blues
normalize: If False, plot the raw numbers
If True, plot the proportions
Usage
-----
plot_confusion_matrix(cm = cm,
normalize = True, # show proportions
target_names = y_labels_vals, # list of classes names
title = best_estimator_name) # title of graph