首先看一下文件夹下包含的文件
其中csv文件是训练NN时,eval某个结果保留下来的混淆矩阵
代码如下
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(0)
x = np.loadtxt(open("ford1_confusion_matrix.csv","rb"),delimiter=",",skiprows=0)
# print(x)
b = np.sum(x,axis=1)
# print(b)
b = b.repeat(5).reshape(5,5) #主要是为了统计每一类的准确率
# print(b)
x = x/b
# print(x)
f, (ax1, ax2) = plt.subplots(figsize=(8,8),nrows=2)
sns.heatmap(x, annot=True, ax=ax1,cmap="YlGnBu")
sns.heatmap(x, annot=True, ax=ax2,cmap="YlGnBu", annot_kws={'size':9,'weight':'bold', 'color':'blue'})
# Keyword arguments for ax.text when annot is True.
# http://stackoverflow.com/questions/35024475/seaborn-heatmap-key-words
# plt.show()
f.savefig('test.jpg')
撒花撒花~~~