Python 绘制混淆矩阵

这篇文章的很多代码都源自下面列出的博客,我只是把它们重新整合,改成了我需要的、美观且适合放在论文中的绘图代码。

参考链接:

https://blog.csdn.net/weixin_38314865/article/details/88989506

https://www.cnblogs.com/ZHANG576433951/p/11233159.html

https://blog.csdn.net/qq_37851620/article/details/100642566?utm_source=app&app_version=4.7.1

https://blog.csdn.net/Poul_henry/article/details/88294297

https://mathpretty.com/10675.html

import numpy as np
import itertools
import matplotlib.pyplot as plt


# Plot a confusion matrix as an annotated heat map and save it to disk.
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    Print a status line, draw ``cm`` as an annotated heat map and save it.

    The figure is written to ``confusion_matrix.eps`` and
    ``confusion_matrix.png`` in the current working directory.

    Input
    - cm : 2-D array-like of counts (the computed confusion matrix)
    - classes : tick labels, one per row / column of ``cm``
    - normalize : True: every cell shows the row-wise percentage with the
      raw count underneath, False: every cell shows the raw count only
    - title : text drawn above the plot
    - cmap : matplotlib colormap used for the cell colours
    """
    # Accept nested lists as well as ndarrays; keep the raw counts around
    # because the normalized view annotates each cell with both values.
    cm = np.asarray(cm)
    counts = cm
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A class without any samples has a zero row total. Dividing by it
        # would fill the row with NaN (and poison the colour threshold
        # below), so such rows are left at 0 instead.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    # Size of the output image in inches. plt.subplots already creates the
    # figure, so no separate plt.figure() call is needed (it would only
    # leave an empty extra figure behind).
    figsize = 8, 6
    _, ax = plt.subplots(figsize=figsize)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)

    # Font and size of the title.
    font_title = {'family': 'Times New Roman',
                  'weight': 'normal',
                  'size': 15,
                  }
    plt.title(title, fontdict=font_title)
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # Size and font of the tick labels.
    plt.tick_params(labelsize=15)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname('Times New Roman')

    # Cells darker than half of the maximum get white text so the
    # annotation stays readable on the dark end of the colormap.
    thresh = cm.max() / 2.
    cell_font = {'family': "Times New Roman", 'weight': "normal", 'size': 15}
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > thresh else "black"
        if normalize:
            # Percentage sits just above the cell centre ...
            plt.text(j, i, format(cm[i, j], '.3%'),
                     horizontalalignment="center", verticalalignment='bottom',
                     color=color, **cell_font)
            # ... and the raw count just below it.
            plt.text(j, i, format(counts[i, j], 'd'),
                     horizontalalignment="center", verticalalignment='top',
                     color=color, **cell_font)
        else:
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center", verticalalignment='bottom',
                     color=color)

    plt.tight_layout()
    # Axis titles are not drawn; add plt.ylabel('True label', ...) and
    # plt.xlabel('Predicted label', ...) here if they are wanted.

    plt.savefig('confusion_matrix.eps', dpi=600, format='eps')
    plt.savefig('confusion_matrix.png', dpi=600, format='png')


# Example input: one tick label per row / column of the 5x5 count matrix.
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
cnf_matrix = np.array([
    [109653, 2, 0, 1, 0],
    [0, 104180, 2, 0, 0],
    [1, 0, 110422, 1, 0],
    [9, 1, 1, 104380, 0],
    [13, 0, 0, 3, 767875],
])

# Draw the normalized view (percentage plus raw count in every cell).
plot_confusion_matrix(
    cnf_matrix,
    classes=attack_types,
    normalize=True,
    title='Normalized confusion matrix',
)

 

你可能感兴趣的:(网络安全与人工智能,Python)