python 绘制混淆矩( confusion matrix)

代码

import matplotlib.pyplot as plt
import numpy as np

classes = ['ang', 'hap', 'neu', 'sad']#标签列表
confusion_matrix = np.array(([91, 1, 4, 2], [6, 92, 2, 2], [2, 3, 92, 3], [8, 13, 4, 90]))#二维混淆矩阵

plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Oranges)  # 按照像素显示出矩阵
plt.title('confusion_matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=0)  # 倾斜
plt.yticks(tick_marks, classes)

thresh = confusion_matrix.max() / 2.
# iters = [[i,j] for i in range(len(classes)) for j in range((classes))]
# ij配对,遍历矩阵迭代器
iters = np.reshape([[[i, j] for j in range(4)] for i in range(4)], (confusion_matrix.size, 2))
for i, j in iters:
    plt.text(j, i, format(confusion_matrix[i, j]), va='center', ha='center')  # 显示对应的数字

plt.ylabel('Real label')
plt.xlabel('Prediction')
plt.tight_layout()
# plt.show()
plt.savefig('confusion_matrix2.png', format='png')

结果展示

python 绘制混淆矩( confusion matrix)_第1张图片

你可能感兴趣的:(机器学习,python,matlab,机器学习)