# Reference: https://blog.csdn.net/dipizhong7224/article/details/104579159
# Official documentation: https://github.com/scikit-learn/scikit-learn/blob/7f9bad99d6e0a3e8ddf92a7e5561245224dab102/sklearn/metrics/_classification.py#L1551
def classification_report_topk(y_true, y_pred, topk=1, labelnames=None, digits=2, output_dict=False,):
    """Build a per-label precision/recall/F1 report for top-k predictions.

    A sample counts as a hit for a label when that label is its true label
    and also appears among the first ``topk`` entries of its prediction row.

    Parameters
    ----------
    y_true : sequence
        True label of every sample, e.g. ``[1, 1, 2, 3]``.
    y_pred : sequence of sequences
        Ranked candidate labels for every sample, e.g.
        ``[[1, 3], [3, 2], [2, 3], [1, 2]]``. Only the first ``topk``
        entries of each row are used for the metrics.
    topk : int, default 1
        Number of leading candidates per row to consider. Must not exceed
        the length of the first prediction row.
    labelnames : sequence, optional
        Labels to report on, e.g. ``[1, 2, 3]``. When omitted, the labels
        are derived with ``sklearn.utils.multiclass.unique_labels``.
    digits : int, default 2
        Number of decimals used in the text report.
    output_dict : bool, default False
        When true, return a dict instead of a formatted string.

    Returns
    -------
    str or dict
        The text report, or a dict mapping each row name (``str(label)``,
        ``'accuracy'``, ``'macro avg'``, ``'weighted avg'``) to a dict with
        the keys ``precision``, ``recall``, ``f1-score`` and ``support``.
        The ``accuracy`` row repeats the accuracy in all three metric
        columns; it and both average rows carry the sample count as support.

    Raises
    ------
    AssertionError
        If ``topk`` is larger than the length of the first prediction row.
    """
    # Local imports: the block relies on numpy but no module-level import
    # is visible, so the name is brought into scope here.
    from itertools import chain

    import numpy as np

    # Kept as an assert so callers see the same exception type as before.
    assert topk <= len(y_pred[0]), 'topk out of bounds!'

    # ``is None`` (not ``== None``): comparing a numpy array of label names
    # with ``==`` yields an array whose truth value is ambiguous.
    if labelnames is None:
        from sklearn.utils.multiclass import unique_labels
        if isinstance(y_pred, list):
            # chain.from_iterable flattens in linear time and also accepts
            # rows that are tuples or arrays, unlike ``sum(y_pred, [])``.
            labelnames = unique_labels(y_true, list(chain.from_iterable(y_pred)))
        elif isinstance(y_pred, np.ndarray):
            labelnames = unique_labels(y_true, y_pred.flatten())
        else:
            # NOTE(review): for other container types only the true labels
            # are used; labels that occur solely in y_pred are not reported.
            labelnames = unique_labels(y_true, y_true)

    # Materialise y_true so counting and indexing work for lists, tuples
    # and numpy arrays alike.
    true_labels = list(y_true)
    n_samples = len(true_labels)
    top_preds = [each[0:topk] for each in y_pred]

    rows = []
    hits_total = 0
    for label in labelnames:
        support = true_labels.count(label)  # TP + FN
        predicted = sum(1 for row in top_preds if label in row)  # TP + FP
        tp = sum(
            1
            for idx, truth in enumerate(true_labels)
            if truth == label and label in top_preds[idx]
        )
        # The hits summed over all reported labels feed the accuracy row.
        hits_total += tp
        if tp:
            # tp > 0 implies both denominators are positive.
            precision = tp / predicted
            recall = tp / support
            f1_score = 2 / ((1 / precision) + (1 / recall))
        else:
            precision = recall = f1_score = 0.0
        rows.append([str(label), precision, recall, f1_score, support])

    label_rows = list(rows)
    # Report 0.0 rather than dividing by zero when there are no samples.
    accuracy_topk = hits_total / n_samples if n_samples else 0.0
    rows.append(['accuracy', accuracy_topk, accuracy_topk, accuracy_topk, n_samples])

    supports = [row[4] for row in label_rows]
    metric_columns = [[row[col] for row in label_rows] for col in (1, 2, 3)]
    for avg_name, weights in (("macro", None), ("weighted", supports)):
        if not label_rows or (weights is not None and not any(weights)):
            # np.average cannot normalise zero total weight (it raises
            # ZeroDivisionError); follow the 0.0 convention used above.
            averaged = [0.0, 0.0, 0.0]
        else:
            averaged = [np.average(column, weights=weights) for column in metric_columns]
        rows.append([avg_name + ' avg', *averaged, n_samples])

    headers = ["precision", "recall", "f1-score", "support"]
    if output_dict:
        return {
            row[0]: dict(zip(headers, [float(value) for value in row[1:]]))
            for row in rows
        }

    # Text layout: right-aligned name column followed by four 9-wide columns.
    longest_last_line_heading = "weighted avg"
    name_width = max(len(row[0]) for row in rows)
    width = max(name_width, len(longest_last_line_heading), digits)
    head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
    row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
    parts = [head_fmt.format("", *headers, width=width), "\n\n"]
    parts.extend(row_fmt.format(*row, width=width, digits=digits) for row in rows)
    parts.append("\n")
    return "".join(parts)