有的时候我们需要将ROC曲线输出在同一张图中,这样可以更加直观地对比模型;并且我们常常会遇到在图形中有文字相互遮挡的问题,我们可以用adjustText中的adjust_text来实现文本不相互遮挡并添加箭头的功能。
def multi_models_roc(names, prob_results, colors,linestyles, y_test, save=True, dpin=100):
"""
将多个机器模型的roc图输出到一张图上
Args:
names: list, 多个模型的名称
prob_results: 使用模型预测的概率值(predict_proba()函数的返回值)
colors: 想绘制的曲线的颜色列表
linestyles: 想绘制的曲线的线型
save: 选择是否将结果保存(默认为png格式)
Returns:
返回图片对象plt
"""
plt.figure(figsize=(10, 10), dpi=dpin)
from adjustText import adjust_text
texts = []
for (name, result, colorname,linestylename) in zip(names, prob_results, colors, linestyles):
y_test_predprob = result[:,1]
fpr, tpr, thresholds = roc_curve(y_test, y_test_predprob)
optimal_th, optimal_point = Find_Optimal_Cutoff(TPR=tpr, FPR=fpr, threshold=thresholds)
# plt.plot(optimal_point[0], optimal_point[1], marker='o', color='r')
# texts.append(plt.text(optimal_point[0], optimal_point[1], name+' '+f'Threshold:{optimal_th:.2f}'))
texts.append(plt.text(optimal_point[0], optimal_point[1], name))
plt.plot(fpr, tpr, lw=3, label='{} (AUC={:.3f})'.format(name, auc(fpr, tpr)),color = colorname,linestyle=linestylename)
plt.plot([0, 1], [0, 1], '--', lw=3, color = 'grey')
plt.axis('square')
plt.xlim([0, 1])
plt.ylim([0, 1.05])
plt.xlabel('False Positive Rate',fontsize=10)
plt.ylabel('True Positive Rate',fontsize=10)
plt.title('ROC Curve',fontsize=20)
plt.legend(loc='lower right',fontsize=10)
adjust_text(texts,
arrowprops=dict(
arrowstyle='->',#箭头样式
lw= 2,#线宽
color='red')#箭头颜色
)
if save:
plt.savefig('multi_models_roc.png')
return plt
names = ['Logistic Regression',
'Naive Bayes',
'Decision Tree',
'Random Forest',
'SVM',
'Neural Network',
'GBDT',
'LightGBM',
'XGBoost'
]
#这是各个模型的预测值返回列表
prob_results = [lg_y_prob,
nb_y_prob,
tree_y_prob,
rf_y_prob,
svm_y_prob,
bp_y_prob,
gbdt_y_prob,
lgb_y_prob,
xgb_y_prob
]
colors = ['crimson',
'orange',
'gold',
'mediumseagreen',
'steelblue',
'mediumpurple' ,
'black',
'silver',
'navy'
]
linestyles = ['-', '--', '-.', ':', 'dotted', 'dashdot', '--', 'solid', 'dashed']
#ROC curves
train_roc_graph = multi_models_roc(names, prob_results, colors, linestyles, Y_test_smo_tmo, save = True)
train_roc_graph.savefig('ROC_Train_all.png')