python 绘制混淆矩阵

首先看一下文件夹下包含的文件
在这里插入图片描述
其中csv文件是训练NN时,eval某个结果保留下来的混淆矩阵
python 绘制混淆矩阵_第1张图片
代码如下

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(0)

x = np.loadtxt(open("ford1_confusion_matrix.csv","rb"),delimiter=",",skiprows=0)
# print(x)

b = np.sum(x,axis=1)
# print(b)
b = b.repeat(5).reshape(5,5) #主要是为了统计每一类的准确率
# print(b)

x = x/b
# print(x)


f, (ax1, ax2) = plt.subplots(figsize=(8,8),nrows=2)

sns.heatmap(x, annot=True, ax=ax1,cmap="YlGnBu")
sns.heatmap(x, annot=True, ax=ax2,cmap="YlGnBu", annot_kws={'size':9,'weight':'bold', 'color':'blue'})
# Keyword arguments for ax.text when annot is True.
# http://stackoverflow.com/questions/35024475/seaborn-heatmap-key-words
# plt.show()
f.savefig('test.jpg')

最终的实验结果为:
python 绘制混淆矩阵_第2张图片

撒花撒花~~~

你可能感兴趣的:(python)