python 画混淆矩阵

import matplotlib as mpl
import random
mpl.use('Agg')

import matplotlib.pyplot as plt
import numpy as np

import os
import warnings

# TrueType font with CJK glyphs, used for the Chinese title and axis labels
# drawn by plot_confusion_matrix. A raw string keeps the Windows backslashes
# literal: \c, \d, \e, \p and \z are not valid escape sequences and raise a
# SyntaxWarning on Python 3.12+.
FONT_PATH = r'H:\competition\downmodel\evaluation\paintpic\zh.ttf'
if os.path.isfile(FONT_PATH):
    custom_font = mpl.font_manager.FontProperties(fname=FONT_PATH)
else:
    # Fall back to matplotlib's default font so the script still runs on other
    # machines; Chinese characters may then render as empty boxes.
    warnings.warn("CJK font not found at %s; using the default font" % FONT_PATH)
    custom_font = mpl.font_manager.FontProperties()
def plot_confusion_matrix(y_true, y_pred, labels, normalize=False,
                          save_path='confusion_matrix.jpg'):
    """Draw, save and show a confusion-matrix heat map, then print Cohen's kappa.

    Every cell is annotated in red. With ``normalize`` False the annotation is
    the raw count; with ``normalize`` True it is the row-normalised rate, and
    rates <= 0.01 are shown as "0". The Chinese title and axis labels use the
    module-level ``custom_font``.

    Args:
        y_true: ground-truth class of each sample.
        y_pred: predicted class of each sample, same length as ``y_true``.
        labels: tick labels for both axes, one per class, in the sorted class
            order produced by ``sklearn.metrics.confusion_matrix``.
            NOTE(review): assumes every class in ``labels`` occurs in
            ``y_true`` or ``y_pred``; otherwise the matrix is smaller than
            the tick grid and indexing fails.
        normalize: annotate and colour with row-normalised rates instead of
            raw counts. Defaults to False (raw counts), the original behaviour.
        save_path: file the figure is saved to, at 300 dpi.
    """
    from sklearn.metrics import confusion_matrix

    # pyplot.get_cmap survives the matplotlib 3.9 removal of cm.get_cmap.
    cmap = plt.get_cmap('Accent_r')
    cm = confusion_matrix(y_true, y_pred)
    n_classes = len(labels)

    # Row-normalise; a row whose true class never occurs stays 0 instead of
    # becoming NaN through a 0/0 division.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_sums,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_sums != 0)

    fig = plt.figure(figsize=(10, 8), dpi=120)
    ind_array = np.arange(n_classes)
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        if normalize:
            c = cm_normalized[y_val][x_val]
            # Tiny rates are printed as a plain "0" to keep the grid readable.
            text = "%0.2f" % (c,) if c > 0.01 else "%d" % (0,)
            plt.text(x_val, y_val, text, color='red', fontsize=7, va='center', ha='center')
        else:
            plt.text(x_val, y_val, "%d" % (cm[y_val][x_val],), color='red', fontsize=8,
                     va='center', ha='center')
    plt.imshow(cm_normalized if normalize else cm, interpolation='nearest', cmap=cmap)

    # Minor ticks sit between cells so the grid lines outline each cell.
    tick_marks = np.arange(n_classes) + 0.5
    ax = plt.gca()
    ax.set_xticks(tick_marks, minor=True)
    ax.set_yticks(tick_marks, minor=True)
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    fig.subplots_adjust(bottom=0.15)
    plt.title(u'ResNet模型预测混淆矩阵结果', fontproperties=custom_font)
    plt.colorbar()
    xlocations = np.arange(n_classes)
    plt.xticks(xlocations, labels, rotation=90)
    plt.yticks(xlocations, labels)
    plt.ylabel(u'滚动轴承真实类别', fontproperties=custom_font)
    plt.xlabel(u'滚动轴承预测类别', fontproperties=custom_font)
    plt.savefig(save_path, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    # Pass the real class count instead of a hard-coded 10.
    print("kappa is:" + str(kappa(cm, len(cm))))


def kappa(testData, k=None):
    """Return Cohen's kappa for a square confusion matrix of counts.

    kappa = (p_o - p_e) / (1 - p_e), where, with N the total number of samples
    in the matrix, p_o = trace / N is the observed agreement and
    p_e = sum_i(row_sum_i * col_sum_i) / N**2 is the agreement expected by
    chance.

    Args:
        testData: k x k confusion matrix (nested lists or ndarray).
        k: number of classes. Optional and inferred from ``testData`` when
            omitted; kept for callers that already pass it.

    Returns:
        Cohen's kappa as a float, or NaN when it is undefined (an all-zero
        matrix, or chance agreement p_e == 1).

    Raises:
        ValueError: if ``testData`` is not square or does not match ``k``.
    """
    matrix = np.asarray(testData, dtype=float)  # np.mat was removed in NumPy 2.0
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError("testData must be a square confusion matrix, got shape %r"
                         % (matrix.shape,))
    if k is not None and k != matrix.shape[0]:
        raise ValueError("k=%r does not match the %d x %d confusion matrix"
                         % (k, matrix.shape[0], matrix.shape[1]))

    total = matrix.sum()
    if total == 0:
        return float('nan')
    # Normalise by the sample count N, not by the number of classes k.
    p_observed = float(np.trace(matrix)) / total
    p_expected = float(matrix.sum(axis=1) @ matrix.sum(axis=0)) / total ** 2
    if p_expected >= 1.0:
        return float('nan')
    return float((p_observed - p_expected) / (1.0 - p_expected))

def my_classification_report(y_true, y_pred):
    """Print per-class precision, recall and F1 plus summary averages.

    Thin wrapper around ``sklearn.metrics.classification_report``, used to
    judge how well the classifier performs on each class.

    Args:
        y_true: ground-truth class of each sample.
        y_pred: predicted class of each sample.
    """
    from sklearn.metrics import classification_report as sk_classification_report

    print("classification_report(left: labels):")
    report_text = sk_classification_report(y_true, y_pred)
    print(report_text)

# Small hand-written 6-class example (ten samples per class); only used by the
# commented-out call at the end of this section.
ytrue = [label for label in range(1, 7) for _ in range(10)]
ypred = [
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    5, 2, 2, 5, 5, 2, 2, 2, 2, 4,
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    2, 4, 5, 5, 5, 5, 2, 5, 5, 5,
    6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
]

# Synthetic 10-class demo with 60 samples per class. For each class, a random
# 52-56 of its predictions are correct; the remaining ones are random guesses
# in 0-9 (which can still hit the right class by chance).
listtrue = [label for label in range(10) for _ in range(60)]

listpred = []
for label in range(10):
    n_hits = random.randint(52, 56)
    listpred.extend([label] * n_hits)
    listpred.extend(random.randint(0, 9) for _ in range(60 - n_hits))

plot_confusion_matrix(listtrue, listpred, list(range(10)))
my_classification_report(listtrue, listpred)
# plot_confusion_matrix(ytrue, ypred, [1, 2, 3, 4, 5, 6])

如下图:

python 画混淆矩阵_第1张图片

其中 get_cmap 的参数可以取以下颜色映射(colormap)名称:Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r 等(完整列表可通过 plt.colormaps() 获取)。名称末尾加 _r 表示该颜色映射的反转版本,即颜色取反。

你可能感兴趣的:(Python)