Plotting a confusion matrix with matplotlib: how do I plot a confusion matrix?

I am using scikit-learn to classify text documents (22,000 of them) into 100 classes. I compute the confusion matrix with scikit-learn's confusion_matrix function.

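from sklearn import metrics
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt

# matrix, labels, test_matrix and test_labels are defined elsewhere (not shown)
model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm = metrics.confusion_matrix(test_labels, pred)
print(cm)
plt.imshow(cm, cmap='binary')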

This is what my confusion matrix looks like:

[[3962  325    0 ...,    0    0    0]
 [ 250 2765    0 ...,    0    0    0]
 [   2    8   17 ...,    0    0    0]
 ...,
 [   1    6    0 ...,    5    0    0]
 [   1    1    0 ...,    0    0    0]
 [   9    0    0 ...,    0    0    9]]

However, the resulting plot is not clear or legible. Is there a better way to do this?

Solution

You can use plt.matshow() instead of plt.imshow(), or you can use the seaborn module's heatmap() function (see its documentation) to plot the confusion matrix:

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

# Example 11x11 confusion matrix
array = [[33,2,0,0,0,0,0,0,0,1,3],
         [3,31,0,0,0,0,0,0,0,0,0],
         [0,4,41,0,0,0,0,0,0,0,1],
         [0,1,0,30,0,6,0,0,0,0,1],
         [0,0,0,0,38,10,0,0,0,0,0],
         [0,0,0,3,1,39,0,0,0,0,4],
         [0,2,2,0,4,1,31,0,0,0,2],
         [0,1,0,0,0,0,0,36,0,2,0],
         [0,0,0,0,0,0,1,5,37,5,1],
         [3,0,0,0,0,0,0,0,0,39,0],
         [0,0,0,0,0,0,0,0,0,0,38]]

# Label rows and columns A..K so the heatmap axes are readable
df_cm = pd.DataFrame(array, index=list("ABCDEFGHIJK"),
                     columns=list("ABCDEFGHIJK"))
plt.figure(figsize=(10, 7))
sn.heatmap(df_cm, annot=True)  # annot=True writes each count inside its cell
plt.show()
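The plt.matshow() route mentioned above is not shown in the example, so here is a minimal sketch of it. It assumes that with 100 classes of very different sizes the raw counts are dominated by a few large cells, and so divides each row by its sum before plotting. That normalization step is my addition, not part of the original answer, and the helper names normalize_rows and plot_confusion_matshow are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np


def normalize_rows(cm):
    """Divide each row by its sum so every row holds per-class fractions."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows that sum to zero stay all-zero instead of dividing by zero.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


def plot_confusion_matshow(cm, figsize=(10, 10)):
    """Draw a row-normalized confusion matrix with Axes.matshow and a colorbar."""
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.matshow(normalize_rows(cm), cmap="Blues", vmin=0.0, vmax=1.0)
    fig.colorbar(image, ax=ax, label="fraction of row total")
    ax.set_xlabel("predicted class")
    ax.set_ylabel("true class")
    return fig, ax


# Small stand-in matrix; in the question this would be the 100x100 `cm`.
example_cm = [[8, 2, 0],
              [1, 9, 0],
              [0, 0, 0]]
fig, ax = plot_confusion_matshow(example_cm)
```

To view the result, call plt.show() in an interactive session or fig.savefig(...) with a file name of your choice. The axis labels assume scikit-learn's convention for confusion_matrix, where rows are true classes and columns are predicted classes.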
