python是用matplotlib和seaborn.heatmap()绘制混淆矩阵

本文主要是对自己做实验需要绘制混淆矩阵从而做一个简单的记录。主要解决以下几个问题:

  1. 矩阵数据以行为基准进行归一化;
  2. 显示x轴、y轴的真实标签而不是数字;
  3. 调整轴上标签字体的样式、字号、显示方向、对齐方式等;
  4. 标签文字过长,默认画布无法显示完整。
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

label_txt = ['drink water','eat meal/snack','brushing teeth','brushing hair','drop','pickup',
'throw','sitting down','standing up (from sitting position)','clapping','reading','writing',
'tear up paper','wear jacket','take off jacket','wear a shoe','take off a shoe','wear on glasses'
,'take off glasses','put on a hat/cap','take off a hat/cap','cheer up','hand waving','kicking something'
,'reach into pocket','hopping (one foot jumping)','jump up','make a phone call/answer phone'
,'playing with phone/tablet','typing on a keyboard','pointing to something with finger'
,'taking a selfie','check time (from watch)','rub two hands together','nod head/bow','shake head'
,'wipe face','salute','put the palms together','cross hands in front (say stop)','sneeze/cough'
,'staggering','falling','touch head (headache)','touch chest (stomachache/heart pain)'
,'touch back (backache)','touch neck (neckache)','nausea or vomiting condition'
,'use a fan (with hand or paper)/feeling warm','punching/slapping other person'
,'kicking other person','pushing other person','pat on back of other person'
,'point finger at the other person','hugging other person','giving something to other person'
,'touch other person\'s pocket','handshaking','walking towards each other','walking apart from each other']

# for confusion matrix
sns.set()
fig, ax = plt.subplots()
# 手动生成一个60*60且值为0到60的矩阵
cm = np.random.randint(0, 60, (60, 60))
'''
如果有真实的训练数据,可直接根据函数得到目标矩阵
label为真实值,preds_label为测试集对应的预测标签值
cm = confusion_matrix(label, preds_label)
'''
# 以行为基准将数据化为0-1之间的小数
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# 使用heatmap()函数绘制混淆矩阵,具体使用方法自行搜索
sns.heatmap(cm_normalized, ax=ax, cmap="YlGnBu", linewidths=0.3, square=True,
    xticklabels=1, yticklabels=1) # 画热力图
# ax.set_title('confusion matrix') #标题
# ax.set_xlabel('predict') #x轴
# ax.set_ylabel('true') #y轴
# 绘制真实标签值,同时设置显示字体,字号,以及字体显示方向,对齐方式等
ax.set_xticklabels(label_txt, rotation=50, horizontalalignment='left', family='Times New Roman', fontsize=5)
ax.set_yticklabels(label_txt, rotation=0, family='Times New Roman', fontsize=5)
# 让x轴的标签显示的图表的上方
ax.xaxis.set_ticks_position("top")
# 让图标自适应,防止标签文字过长而图显示不完整
fig.tight_layout()
# 保存图片并设置清晰度
fig.savefig('test.png', dpi=500)
plt.show()

你可能感兴趣的:(python,深度学习,数据可视化,混淆矩阵)