PR曲线代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score

def get_pr(oof, target):
    """Compute precision-recall curve points for binary scores.

    Samples are ranked by ``oof`` in descending order, and every rank is
    treated as a cut-off. After admitting the top ``k`` samples, precision
    is ``tp / k`` and recall is ``tp / n_pos``. Tied scores are not merged,
    so every sample contributes its own point.

    Parameters
    ----------
    oof : array-like of shape (n_samples,)
        Predicted scores; higher means "more likely positive".
    target : array-like of shape (n_samples,)
        Ground-truth labels; entries equal to 1 count as positives.

    Returns
    -------
    precision : list of float
        Precision after each of the ``n_samples`` ranked samples.
    recall : list of float
        Recall after each of the ``n_samples`` ranked samples.
    auc : float
        ``roc_auc_score(target, oof)``. This is the area under the ROC
        curve, NOT the area under the precision-recall curve returned here.

    Raises
    ------
    ValueError
        If ``oof`` and ``target`` differ in length, or if ``target`` has no
        positive (== 1) samples (recall would be undefined).
    """
    scores = np.asarray(oof)
    labels = np.asarray(target)
    if scores.shape[0] != labels.shape[0]:
        raise ValueError(
            f"oof and target must have the same length, "
            f"got {scores.shape[0]} and {labels.shape[0]}"
        )

    n_pos = int(np.count_nonzero(labels == 1))
    if n_pos == 0:
        raise ValueError(
            "target contains no positive samples (label 1); recall is undefined"
        )

    # Reorder positionally via numpy. Indexing a pandas Series with an
    # integer array would look up index *labels*, which is wrong whenever
    # target does not carry a default 0..n-1 index.
    order = np.argsort(scores)[::-1]
    is_pos = labels[order] == 1

    # Running count of true positives after admitting each ranked sample;
    # the number admitted so far (tp + fp) is simply the rank 1..n.
    tp = np.cumsum(is_pos)
    n_seen = np.arange(1, is_pos.shape[0] + 1)

    precision = (tp / n_seen).tolist()
    recall = (tp / n_pos).tolist()

    auc = roc_auc_score(target, oof)
    return precision, recall, auc

# Compute PR-curve points (plus the ROC AUC shown in the legend) for both
# models. NOTE(review): oof_lgb, oof_LR and target are expected to be defined
# earlier (presumably out-of-fold predictions from a previous cell).
precision_lgb, recall_lgb, auc_lgb = get_pr(oof_lgb, target)
precision_LR, recall_LR, auc_LR = get_pr(oof_LR, target)

plt.figure(figsize=(12, 8))
# get_pr's third return value comes from roc_auc_score, so the legend calls
# it "ROC AUC" to avoid it being read as the area under the plotted PR curve.
plt.plot(
    recall_lgb,
    precision_lgb,
    label=f"LightGBM (ROC AUC: {auc_lgb:.3f})",
    linewidth=3,
)
# NOTE(review): the "159" prefix looks like a stray line number picked up
# when this code was copied; confirm and drop it.
plt.plot(
    recall_LR,
    precision_LR,
    label=f"159 Logistic Regression (ROC AUC: {auc_LR:.3f})",
    linewidth=3,
)
plt.xlabel("Recall", fontsize=16)
plt.ylabel("Precision", fontsize=16)
plt.title("Precision Recall Curve", fontsize=17)
plt.legend(fontsize=16)

PR曲线代码_第1张图片

你可能感兴趣的:(代码,学习记录)