Matplotlib绘图——混淆矩阵

import warnings
warnings.filterwarnings('ignore') 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

# 设置全局字体及大小,设置公式字体
config = {
    "font.family":'serif',        # 衬线字体
    "font.size": 12,              # 相当于小四大小
    "mathtext.fontset":'stix',    # matplotlib渲染数学字体时使用的字体,和Times New Roman差别不大
    "font.serif": ['SimSun'],     # 宋体SimSun
    "axes.unicode_minus": False,  # 用来正常显示负号
    "xtick.direction":'in',       # 横坐标轴的刻度设置向内(in)或向外(out)
    "ytick.direction":'in',       # 纵坐标轴的刻度设置向内(in)或向外(out)
}
rcParams.update(config)

%matplotlib inline 
# 内置魔法函数,不用再 plt.show()
classes = ['一类','二类','三类','四类']
confusion_matrix = np.array([(900, 300, 200, 100),
                             (100, 800, 300, 200),
                             (300, 200, 700, 100),
                             (200, 100, 300, 600)],dtype=np.int)   # 混淆矩阵

proportion = []
for i in confusion_matrix:
    for j in i:
        temp=j/(np.sum(i))
        proportion.append(temp)

pshow = []
for i in proportion:
    pt="%.2f%%" % (i * 100)
    pshow.append(pt)
proportion = np.array(proportion).reshape(confusion_matrix.shape[0], confusion_matrix.shape[1])
pshow = np.array(pshow).reshape(confusion_matrix.shape[0], confusion_matrix.shape[1])   # reshape(列的长度,行的长度)

plt.figure(figsize=(5,3))
plt.imshow(proportion, interpolation='nearest', cmap=plt.cm.Blues)  # 按照像素显示出矩阵
plt.colorbar().ax.tick_params(labelsize=10)    # 设置右侧色标刻度大小

tick_marks = np.arange(len(classes))   # [0, 1, 2, 3]
plt.xticks(tick_marks, classes, fontsize=10)
plt.yticks(tick_marks, classes, fontsize=10)
ax = plt.gca()
# 设置 横轴 刻度 标签 显示在顶部
ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False)  

# thresh = confusion_matrix.max() / 2.
# ij配对,遍历矩阵迭代器
iters = np.reshape([[[i,j] for j in range(4)] for i in range(4)],(confusion_matrix.size,2))
for i, j in iters:
    if(i==j):
        # 仅居中显示数字
        # plt.text(j, i, format(confusion_matrix[i, j]), va='center', ha='center', fontsize=10,color='white',weight=5)  
        
        # 同时居中显示数字和百分比
        plt.text(j, i-0.12, format(confusion_matrix[i, j]), va='center', ha='center', fontsize=10,color='white',weight=5)  # 显示数字
        plt.text(j, i+0.12, pshow[i, j], va='center', ha='center', fontsize=10, color='white')  # 显示百分比
    else:
        # 仅居中显示数字
        # plt.text(j, i, format(confusion_matrix[i, j]),va='center',ha='center',fontsize=10)
        
        # 同时居中显示数字和百分比
        plt.text(j, i-0.12, format(confusion_matrix[i, j]),va='center',ha='center',fontsize=10)   #显示数字
        plt.text(j, i+0.12, pshow[i, j], va='center', ha='center', fontsize=10)  # 显示百分比

# plt.title('confusion_matrix')
plt.ylabel('实际', fontsize=12)
plt.xlabel('预测', fontsize=12)
ax = plt.gca()
# 设置 横轴标签 显示在顶部
ax.xaxis.set_label_position('top')   
plt.tight_layout()    # 自动调整子图参数,使之填充整个图像区域,并且防止子图标签堆叠

plt.savefig(r'D:\Users\Administrator\Desktop\混淆矩阵.png', dpi=600, bbox_inches='tight')
plt.show()

Matplotlib绘图——混淆矩阵_第1张图片
参考资料:
https://blog.csdn.net/m0_46295727/article/details/123143537
https://blog.csdn.net/weixin_43818631/article/details/121309660

你可能感兴趣的:(Python学习,学习笔记,python,程序人生,经验分享,其他)