python matplotlib绘制混淆矩阵并配色

文章目录

    • 步骤1:网络测试结果保存
    • 步骤2:矩阵绘制
    • 混淆矩阵绘制结果

步骤1:网络测试结果保存

以pytorch为例,在测试阶段保存结果的参考代码为:

resultTxtName = "result.txt"
resultfiledir = os.path.join(web_dir,resultTxtName)
f_result=open(resultfiledir, "a+")

if opt.which_model_netG[-5:] == "class":
	clss_loss_sum = 0
	right_numbers = 0
	for i, data in enumerate(dataset):
		# if i >= opt.how_many:
		# 	break
		counter = i + 1
		# pdb.set_trace()
		model.set_input(data)
		pred, flag, loss, right_number = model.test()
		right_numbers += right_number
		clss_loss_sum += loss
		new_result_context = str(pred.item()) + ';' + str(flag.item())+ '\n'
		f_result.write(new_result_context)
		# print(pred.item(),flag.item())
		if i % 100 == 0:
			print(clss_loss_sum/i,right_numbers / (i+1))
			f_result.close()
			f_result=open(resultfiledir, "a+")

步骤2:矩阵绘制

参考文章:
[1] https://blog.csdn.net/qq_33590958/article/details/103443215

对应代码:

from sklearn.metrics import confusion_matrix    # 生成混淆矩阵的函数
import numpy as np
from matplotlib import pyplot as plt
import pdb

plt.rcParams["font.sans-serif"]=["SimSun"]
'''
首先是从结果文件中读取预测标签与真实标签,然后将读取的标签信息传入python内置的混淆矩阵矩阵函数confusion_matrix(真实标签,
预测标签)中计算得到混淆矩阵,之后调用自己实现的混淆矩阵可视化函数plot_confusion_matrix()即可实现可视化。
三个参数分别是混淆矩阵归一化值,总的类别标签集合,可是化图的标题
'''
 
def plot_confusion_matrix(cm, labels_name, title):
    np.set_printoptions(precision=2)
    # print(cm)
    plt.imshow(cm, interpolation='nearest',cmap='YlOrBr')    # 在特定的窗口上显示图像
    # 显示text
    for first_index in range(len(cm)):    #第几行
        for second_index in range(len(cm[first_index])):    #第几列
            plt.text(first_index, second_index, '%.2f' % cm[first_index][second_index],
            horizontalalignment='center')
    plt.title(title)    # 图像标题
    plt.colorbar()
    num_local = np.array(range(len(labels_name)))    
    plt.xticks(num_local, labels_name, rotation=90)    # 将标签印在x轴坐标上
    plt.yticks(num_local, labels_name)    # 将标签印在y轴坐标上
    plt.ylabel('真实类别')    
    plt.xlabel('预测类别')
    # show confusion matrix
    plt.savefig('./fig/'+title+'.png', format='png')
 
gt = []
pre = []
with open("11c_result.txt", "r") as f:
    for line in f:
        line=line.rstrip()#rstrip() 删除 string 字符串末尾的指定字符(默认为空格)
        words=line.split(';')
        pre.append(int(words[0]))
        gt.append(int(eval(words[1])))
 
cm=confusion_matrix(gt,pre)  #计算混淆矩阵
print('type=',type(cm))
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]    # 归一化
labels = [0,1,2,3,4,5,6,7,8,9,10]  #类别集合
# pdb.set_trace()
plot_confusion_matrix(cm,labels,'隐藏场景分类探测任务的混淆矩阵')  #绘制混淆矩阵图,可视化

相比于文章[1],主要修改了如下几个地方:

  1. 学位论文要求图例中文宋体

plt.rcParams["font.sans-serif"]=["SimSun"]

  1. 为了美观,使用暖色调的colormap

plt.imshow(cm, interpolation='nearest',cmap='YlOrBr')

  1. 在混淆矩阵中写入了概率
    for first_index in range(len(cm)):    #第几行
        for second_index in range(len(cm[first_index])):    #第几列
            plt.text(first_index, second_index, '%.2f' % cm[first_index][second_index],
            horizontalalignment='center')

混淆矩阵绘制结果

python matplotlib绘制混淆矩阵并配色_第1张图片

你可能感兴趣的:(#,Deep,Learning,#,可视数据分析,其他,matplotlib,混淆矩阵,colormap)