以pytorch为例,在测试阶段保存结果的参考代码为:
resultTxtName = "result.txt"
resultfiledir = os.path.join(web_dir,resultTxtName)
f_result=open(resultfiledir, "a+")
if opt.which_model_netG[-5:] == "class":
clss_loss_sum = 0
right_numbers = 0
for i, data in enumerate(dataset):
# if i >= opt.how_many:
# break
counter = i + 1
# pdb.set_trace()
model.set_input(data)
pred, flag, loss, right_number = model.test()
right_numbers += right_number
clss_loss_sum += loss
new_result_context = str(pred.item()) + ';' + str(flag.item())+ '\n'
f_result.write(new_result_context)
# print(pred.item(),flag.item())
if i % 100 == 0:
print(clss_loss_sum/i,right_numbers / (i+1))
f_result.close()
f_result=open(resultfiledir, "a+")
参考文章:
[1] https://blog.csdn.net/qq_33590958/article/details/103443215
对应代码:
from sklearn.metrics import confusion_matrix # 生成混淆矩阵的函数
import numpy as np
from matplotlib import pyplot as plt
import pdb
plt.rcParams["font.sans-serif"]=["SimSun"]
'''
首先是从结果文件中读取预测标签与真实标签,然后将读取的标签信息传入python内置的混淆矩阵矩阵函数confusion_matrix(真实标签,
预测标签)中计算得到混淆矩阵,之后调用自己实现的混淆矩阵可视化函数plot_confusion_matrix()即可实现可视化。
三个参数分别是混淆矩阵归一化值,总的类别标签集合,可是化图的标题
'''
def plot_confusion_matrix(cm, labels_name, title):
np.set_printoptions(precision=2)
# print(cm)
plt.imshow(cm, interpolation='nearest',cmap='YlOrBr') # 在特定的窗口上显示图像
# 显示text
for first_index in range(len(cm)): #第几行
for second_index in range(len(cm[first_index])): #第几列
plt.text(first_index, second_index, '%.2f' % cm[first_index][second_index],
horizontalalignment='center')
plt.title(title) # 图像标题
plt.colorbar()
num_local = np.array(range(len(labels_name)))
plt.xticks(num_local, labels_name, rotation=90) # 将标签印在x轴坐标上
plt.yticks(num_local, labels_name) # 将标签印在y轴坐标上
plt.ylabel('真实类别')
plt.xlabel('预测类别')
# show confusion matrix
plt.savefig('./fig/'+title+'.png', format='png')
gt = []
pre = []
with open("11c_result.txt", "r") as f:
for line in f:
line=line.rstrip()#rstrip() 删除 string 字符串末尾的指定字符(默认为空格)
words=line.split(';')
pre.append(int(words[0]))
gt.append(int(eval(words[1])))
cm=confusion_matrix(gt,pre) #计算混淆矩阵
print('type=',type(cm))
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] # 归一化
labels = [0,1,2,3,4,5,6,7,8,9,10] #类别集合
# pdb.set_trace()
plot_confusion_matrix(cm,labels,'隐藏场景分类探测任务的混淆矩阵') #绘制混淆矩阵图,可视化
相比于文章[1],主要修改了如下几个地方:
plt.rcParams["font.sans-serif"]=["SimSun"]
plt.imshow(cm, interpolation='nearest',cmap='YlOrBr')
for first_index in range(len(cm)): #第几行
for second_index in range(len(cm[first_index])): #第几列
plt.text(first_index, second_index, '%.2f' % cm[first_index][second_index],
horizontalalignment='center')