import matplotlib.pyplot as plt
import numpy as np
# 1. Collect and plot the training and test accuracy
def data_plot(path):
    """Plot the training and test accuracy curves recorded in a log file.

    The file is scanned line by line.  A line containing ``"Train Acc"`` or
    ``"Test Acc"`` contributes the text found between its first and second
    ``:`` (parsed as a float) to the matching curve; lines whose value is
    blank are skipped.

    Args:
        path: Path of the UTF-8 encoded log file to read.

    Raises:
        IndexError: If a matching line contains no ``:`` separator.
        ValueError: If the extracted value is not a valid float.
    """
    train_acc = []
    test_acc = []
    with open(path, mode="r", encoding="utf-8") as f:
        for line in f:
            if "Train Acc" in line:
                target = train_acc
            elif "Test Acc" in line:
                target = test_acc
            else:
                continue
            value = line.split(":")[1].strip()
            if value:
                target.append(float(value))

    # Each curve gets its own figure: clearing the current figure between the
    # two plots would wipe the training curve before it is ever displayed.
    for title, curve in (("train", train_acc), ("test", test_acc)):
        plt.figure()
        plt.title(title)
        plt.plot(curve)
    # Without an explicit show() nothing is rendered when run as a script.
    plt.show()
# Script entry point: plot the accuracy curves stored in the sample log file.
if __name__ == "__main__":
    path = r"./[6, 6, 6, 6]_Standard_taz.csv"
    data_plot(path=path)
# 2. Compute precision, recall and F1 from a confusion matrix
def calculate_prediction(metrix):
    """Compute per-class and overall precision (in percent) of a confusion matrix.

    The per-class value is the diagonal entry divided by the total of its
    column; the overall value is the sum of the diagonal divided by the sum
    of every entry.  Both are scaled to percent and rounded to 4 decimals,
    and both results are printed as well as returned.

    NOTE(review): dividing by the column total is precision only if columns
    hold the predicted classes -- confirm against how the matrix is built.

    Args:
        metrix: Square 2-D confusion matrix (NumPy array or nested sequence).

    Returns:
        A tuple ``(label_pre, all_pre)``: the list of per-class values and
        the overall value, all plain Python floats.
    """
    metrix = np.asarray(metrix)
    # Column totals do not change inside the loop, so compute them once.
    column_totals = metrix.sum(axis=0)

    label_pre = []
    correct_sum = 0
    for i in range(metrix.shape[0]):
        correct = metrix[i][i]
        correct_sum += correct
        pre = 0.0
        # A class with an empty column would otherwise divide by zero and
        # produce NaN; report 0 instead, as calculate_recall does.
        if column_totals[i] != 0:
            pre = round(100 * float(correct) / float(column_totals[i]), 4)
        label_pre.append(pre)
    print("每类精度:", label_pre)

    total = float(metrix.sum())
    all_pre = round(100 * float(correct_sum) / total, 4) if total != 0 else 0.0
    print("总精度:", all_pre)
    return label_pre, all_pre
def calculate_recall(metrix):
    """Compute per-class recall and their average (in percent).

    For each class the diagonal entry is divided by the total of its row,
    scaled to percent and rounded to 4 decimals; a class whose row total is
    zero gets 0.  The overall value is the sum of the per-class values
    divided by the number of rows, rounded to 4 decimals.  Both results are
    printed as well as returned.

    Args:
        metrix: 2-D NumPy confusion matrix.

    Returns:
        A tuple ``(label_recall, all_recall)``.
    """
    row_totals = metrix.sum(axis=1)
    label_recall = []
    for idx, row_total in enumerate(row_totals):
        hits = metrix[idx][idx]
        score = (
            round(100 * float(hits) / float(row_total), 4)
            if row_total != 0
            else 0
        )
        label_recall.append(score)
    print("每类召回率:", label_recall)

    class_count = metrix.shape[0]
    all_recall = round(np.sum(label_recall) / class_count, 4)
    print("总召回率:", all_recall)
    return label_recall, all_recall
def _f1_score(pre, reca):
    """Return the harmonic mean of *pre* and *reca* rounded to 4 decimals, or 0 when their sum is 0."""
    if (pre + reca) == 0:
        return 0
    return round(2 * pre * reca / (pre + reca), 4)


def calculate_f1(prediction, all_pre, recall, all_recall):
    """Compute per-class F1 scores and print them with the overall F1.

    Each F1 is ``2 * p * r / (p + r)`` rounded to 4 decimals.  A pair whose
    sum is zero yields 0; the same guard is applied to the overall score so
    that it no longer divides by zero when both overall values are 0.

    Args:
        prediction: Sequence of per-class precision values.
        all_pre: Overall precision value.
        recall: Sequence of per-class recall values, aligned with
            ``prediction``.
        all_recall: Overall recall value.

    Returns:
        The list of per-class F1 scores.  The overall F1 is only printed.
    """
    all_f1 = [_f1_score(pre, reca) for pre, reca in zip(prediction, recall)]
    print("每类f1:", all_f1)
    print("总的f1:", _f1_score(all_pre, all_recall))
    return all_f1
# Script entry point: run the three metric helpers on a sample 5x5 matrix.
if __name__ == "__main__":
    metrix = np.array(
        [
            [84, 30, 16, 4, 4],
            [11, 88, 14, 5, 1],
            [13, 31, 75, 0, 0],
            [12, 15, 3, 71, 1],
            [31, 7, 5, 12, 67],
        ]
    )
    # Show the total of the first column and of the first row.
    first_column_total = metrix.sum(axis=0)[0]
    first_row_total = metrix.sum(axis=1)[0]
    print(first_column_total, first_row_total)

    label_pre, all_pre = calculate_prediction(metrix)
    label_recall, all_recall = calculate_recall(metrix)
    calculate_f1(label_pre, all_pre, label_recall, all_recall)
# 3. Plot the confusion matrix and compute the average confusion matrix
def _parse_confusion_matrices(lines):
    """Collect every bracketed matrix found in *lines* as a list of rows of string tokens."""
    matrices = []
    rows = []
    for line in lines:
        # Lines mentioning "Train" or "Test" are never treated as matrix rows.
        if ("Train" in line) or ("Test" in line):
            continue
        text = line.strip()
        opens = "[[" in text
        closes = "]]" in text
        if not (opens or closes or "[" in text):
            continue
        if opens:
            rows = []
        rows.append(text.strip("[]").split())
        if closes:
            # Checked independently of `opens` so that a one-row matrix
            # written as "[[...]]" on a single line is recorded too.
            matrices.append(rows)
            # Start from a fresh list so stray rows cannot mutate the
            # matrix that was just stored.
            rows = []
    return matrices


def get_Confusion_matrix(path, num_count=200):
    """Average the most recent confusion matrices of a log file and report metrics.

    The file is expected to contain matrices written in bracketed form
    (``[[a b c]``, `` [d e f]``, ... `` [g h i]]``).  The last ``num_count``
    matrices are summed element-wise and divided by the number of matrices
    actually used.  The average is printed, passed to
    ``plot_Confusion_matrix`` (expected to be defined elsewhere in this
    module), and fed to ``calculate_prediction``, ``calculate_recall`` and
    ``calculate_f1``.

    Args:
        path: Path of the UTF-8 encoded log file.
        num_count: Maximum number of trailing matrices to average.

    Raises:
        ValueError: If ``num_count`` is not positive, if the file contains no
            matrix, or if an entry cannot be converted to ``int``.
    """
    if num_count <= 0:
        raise ValueError("num_count must be a positive integer")

    with open(path, mode="r", encoding="utf-8") as f:
        matrices = _parse_confusion_matrices(f)

    recent = matrices[-num_count:]
    if not recent:
        raise ValueError("no confusion matrix found in {!r}".format(path))

    # Size the accumulator from the data instead of assuming 5x5.
    total = np.zeros_like(np.array(recent[0], dtype=int))
    for rows in recent:
        print(rows)
        total += np.array(rows, dtype=int)
    # Divide by the number of matrices actually summed; the log may hold
    # fewer than num_count of them.
    metrix = total / len(recent)
    print(metrix)

    plot_Confusion_matrix(metrix=metrix)
    print(metrix.sum(axis=0)[0], metrix.sum(axis=1)[0])
    label_pre, all_pre = calculate_prediction(metrix)
    label_recall, all_recall = calculate_recall(metrix)
    calculate_f1(label_pre, all_pre, label_recall, all_recall)
# Script entry point: analyse the confusion matrices of one experiment log.
if __name__ == "__main__":
    path = (
        r"C:\Users\Administrator\Desktop\nj单一特征对比实验\12-16号graphsage的结果"
        r"\lstm\[7, 7, 7, 7]_Standard_taz_lstm.csv"
    )
    get_Confusion_matrix(path)