classification_report加入topk计算

参考:https://blog.csdn.net/dipizhong7224/article/details/104579159
官方文档:https://github.com/scikit-learn/scikit-learn/blob/7f9bad99d6e0a3e8ddf92a7e5561245224dab102/sklearn/metrics/_classification.py#L1551

def classification_report_topk(y_true, y_pred, topk=1, labelnames=None, digits=2, output_dict=False):
    """Build a classification report where a sample counts as correct when its
    true label appears among its top-k predicted labels.

    Output mirrors ``sklearn.metrics.classification_report``: one row per label
    with precision / recall / f1-score / support, plus ``accuracy``,
    ``macro avg`` and ``weighted avg`` rows.

    Parameters
    ----------
    y_true : sequence of labels, e.g. ``[1, 1, 2, 3]``.
    y_pred : sequence of ranked prediction lists, e.g.
        ``[[1, 3], [3, 2], [2, 3], [1, 2]]``; each inner sequence is ordered
        best-first and must contain at least ``topk`` entries.
    topk : int, number of leading predictions that count as a hit.
    labelnames : optional explicit label set, e.g. ``[1, 2, 3]``; when omitted
        it is inferred from ``y_true`` and ``y_pred`` via sklearn's
        ``unique_labels``.
    digits : number of decimal places in the formatted text report.
    output_dict : when True return ``{label: {metric: float}}``,
        otherwise a formatted string table.

    Returns
    -------
    dict or str, following sklearn's classification_report conventions.

    Raises
    ------
    ValueError
        If ``y_true`` is empty or ``topk`` exceeds the number of predictions
        available per sample.
    """
    import numpy as np

    # Normalize inputs to plain lists so counting/indexing works uniformly
    # for lists and numpy arrays (the original ndarray branch referenced an
    # unbound `numpy` name and crashed).
    y_true = list(y_true)
    y_pred = [list(row) for row in y_pred]
    if not y_true:
        raise ValueError("y_true must not be empty")
    if topk > len(y_pred[0]):
        raise ValueError("topk out of bounds!")

    if labelnames is None:
        # Infer the label set from everything observed in y_true and y_pred.
        from sklearn.utils.multiclass import unique_labels
        labelnames = unique_labels(y_true, [lbl for row in y_pred for lbl in row])

    # Only the first k ranked predictions of each sample are considered.
    y_pred = [row[:topk] for row in y_pred]

    rows = []
    tp_total = 0  # true positives summed over labels; numerator of top-k accuracy
    for label in labelnames:
        support = y_true.count(label)                          # TP + FN
        tp_fp = sum(1 for preds in y_pred if label in preds)   # TP + FP
        tp = sum(1 for truth, preds in zip(y_true, y_pred)
                 if truth == label and label in preds)         # TP
        tp_total += tp
        if tp:  # tp > 0 implies tp_fp > 0 and support > 0, so divisions are safe
            precision = tp / tp_fp
            recall = tp / support
            f1_score = 2 * precision * recall / (precision + recall)
        else:
            # Matches sklearn's zero_division=0 convention.
            precision = recall = f1_score = 0.0
        rows.append([str(label), precision, recall, f1_score, support])

    # Top-k accuracy: fraction of samples whose true label is in the top k.
    accuracy_topk = tp_total / len(y_true)
    rows.append(['accuracy', accuracy_topk, accuracy_topk, accuracy_topk, len(y_true)])

    # Macro average weighs every label equally; weighted average uses support.
    per_label = rows[:-1]
    precisions = [r[1] for r in per_label]
    recalls = [r[2] for r in per_label]
    f1_scores = [r[3] for r in per_label]
    supports = [r[4] for r in per_label]
    for avg_name, weights in zip(["macro", "weighted"], [None, supports]):
        rows.append([
            avg_name + ' avg',
            np.average(precisions, weights=weights),
            np.average(recalls, weights=weights),
            np.average(f1_scores, weights=weights),
            len(y_true),
        ])

    headers = ["precision", "recall", "f1-score", "support"]
    if output_dict:
        # Same shape as sklearn's output_dict: label -> metric name -> float.
        return {row[0]: dict(zip(headers, [float(v) for v in row[1:]]))
                for row in rows}

    # Plain-text table formatted like sklearn's classification_report.
    target_names = [row[0] for row in rows]
    longest_last_line_heading = "weighted avg"
    name_width = max(len(name) for name in target_names)
    width = max(name_width, len(longest_last_line_heading), digits)
    head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
    report = head_fmt.format("", *headers, width=width)
    report += "\n\n"
    row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
    for row in rows:
        report += row_fmt.format(*row, width=width, digits=digits)
    report += "\n"
    return report

你可能感兴趣的:(scikit-learn,python)