这篇文章中的很多代码都来源于博客,我只是把它们重新整合,改写成了我需要的、能生成漂亮且适合放在论文中的图片的代码。
参考链接:
https://blog.csdn.net/weixin_38314865/article/details/88989506
https://www.cnblogs.com/ZHANG576433951/p/11233159.html
https://blog.csdn.net/qq_37851620/article/details/100642566?utm_source=app&app_version=4.7.1
https://blog.csdn.net/Poul_henry/article/details/88294297
https://mathpretty.com/10675.html
import numpy as np
import itertools
import matplotlib.pyplot as plt
# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues,
                          save_prefix='confusion_matrix', show_axis_labels=False):
    """Plot a confusion matrix in a paper-ready style and save it as EPS and PNG.

    Every text element (title, tick labels, cell annotations, optional axis
    labels) is drawn in Times New Roman, size 15. The figure is written to
    ``<save_prefix>.eps`` and ``<save_prefix>.png`` at 600 dpi.

    Parameters
    ----------
    cm : array-like
        Square matrix of counts. Rows are treated as the true class: with
        ``normalize=True`` each row is divided by its own sum. Values are
        formatted with ``'d'``, so integer counts are assumed.
    classes : sequence of str
        Tick labels, one per row/column of ``cm``.
    normalize : bool
        True: each cell shows the row-normalized percentage (above) and the
        raw count (below), and the colors follow the percentages.
        False: each cell shows only the raw count.
    title : str
        Figure title.
    cmap : matplotlib colormap
        Colormap for the cell background.
    save_prefix : str
        Path prefix for the two output files. The default keeps the
        original names ``confusion_matrix.eps`` / ``confusion_matrix.png``.
    show_axis_labels : bool
        If True, label the axes 'True label' (y) and 'Predicted label' (x).

    Returns
    -------
    tuple
        ``(figure, ax)``, so callers can adjust, show or close the plot.
    """
    counts = np.asarray(cm)
    if normalize:
        row_sums = counts.sum(axis=1, keepdims=True)
        # A class with no samples would give 0/0 = NaN (plus a RuntimeWarning);
        # show such rows as 0% instead.
        values = np.divide(counts, row_sums,
                           out=np.zeros(counts.shape, dtype=float),
                           where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        values = counts
        print('Confusion matrix, without normalization')

    font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}

    # Create exactly one figure. Calling plt.figure() before plt.subplots()
    # left an extra empty figure open on every call.
    figure, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(values, interpolation='nearest', cmap=cmap)
    ax.set_title(title, fontdict=font)
    figure.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)
    ax.tick_params(labelsize=15)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname('Times New Roman')

    # Use white text on dark cells and black text on light cells.
    thresh = values.max() / 2.
    for i, j in itertools.product(range(values.shape[0]), range(values.shape[1])):
        color = "white" if values[i, j] > thresh else "black"
        if normalize:
            # Two lines per cell: the percentage sits just above the center,
            # the raw count just below it.
            ax.text(j, i, format(values[i, j], '.3%'), fontdict=font,
                    horizontalalignment="center", verticalalignment='bottom', color=color)
            ax.text(j, i, format(counts[i, j], 'd'), fontdict=font,
                    horizontalalignment="center", verticalalignment='top', color=color)
        else:
            # A single label per cell, so center it vertically as well.
            ax.text(j, i, format(counts[i, j], 'd'), fontdict=font,
                    horizontalalignment="center", verticalalignment='center', color=color)

    if show_axis_labels:
        ax.set_ylabel('True label', fontdict=font)
        ax.set_xlabel('Predicted label', fontdict=font)

    # Lay out after all labels exist so nothing is clipped in the saved files.
    figure.tight_layout()
    figure.savefig(save_prefix + '.eps', dpi=600, format='eps')
    figure.savefig(save_prefix + '.png', dpi=600, format='png')
    return figure, ax
if __name__ == "__main__":
    # Example run. It only executes when this file is run as a script, so
    # importing plot_confusion_matrix from elsewhere no longer draws and
    # writes files as a side effect.
    cnf_matrix = np.array([[109653, 2, 0, 1, 0],
                           [0, 104180, 2, 0, 0],
                           [1, 0, 110422, 1, 0],
                           [9, 1, 1, 104380, 0],
                           [13, 0, 0, 3, 767875]])
    attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
    plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')