本示例演示了多模型机器学习比较的基本流程：划分训练集/测试集、训练多个分类器、计算 ROC 曲线与 AUC、用约登指数（Youden index）确定最佳阈值，并据此计算准确率、灵敏度与特异度。
注：运行前需先定义特征矩阵 X 和标签 y（替换为自己的数据集）；代码按二分类 0/1 标签进行阈值判定。
#导入库
from sklearn.metrics import confusion_matrix,accuracy_score,f1_score,roc_auc_score,recall_score,precision_score,roc_curve
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
from matplotlib import pyplot
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.model_selection import cross_validate
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import pandas as pd
# --- Train/test split ---
# X (features) and y (labels) must already be defined at this point.
# 30% of the samples are held out as the test set; the fixed seed makes the
# split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=8675309,
)
# Plot a single ROC curve.
def calculate_auc(y_test, pred):
    """Print the ROC AUC of ``pred`` against ``y_test`` and show the ROC curve.

    Despite the name, nothing is returned: the AUC is printed and the curve
    is drawn with matplotlib and displayed via ``plt.show()``.

    Args:
        y_test: True labels (presumably binary 0/1 — confirm for your data).
        pred: Scores or probabilities for the positive class, one per sample.
    """
    print("auc:", roc_auc_score(y_test, pred))
    fpr, tpr, _ = roc_curve(y_test, pred)
    roc_auc = auc(fpr, tpr)
    # The format string only sets the line style. The original 'k-' also
    # selected black, which conflicted with color='blue' and made matplotlib
    # warn that the color was redundantly defined.
    plt.plot(fpr, tpr, '-', label='ROC (area = {0:.2f})'.format(roc_auc), color='blue', lw=2)
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    # Diagonal reference line = a random classifier.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.show()
# Use Youden's J statistic to find the optimal classification threshold.
def Find_Optimal_Cutoff(TPR, FPR, threshold):
    """Return the threshold that maximizes Youden's J = TPR - FPR.

    Args:
        TPR: True-positive rates along the ROC curve (array-like).
        FPR: False-positive rates, aligned element-wise with ``TPR``.
        threshold: Decision thresholds aligned with ``TPR``/``FPR``
            (e.g. the third output of ``roc_curve``).

    Returns:
        ``(optimal_threshold, [fpr, tpr])`` at the maximizing index. On ties
        the first maximizing index is used (``np.argmax`` semantics).
    """
    # numpy is not imported at file level, so the original `np.argmax`
    # raised NameError; import it locally.
    import numpy as np

    # asarray lets plain lists work too (list - list is a TypeError).
    tpr = np.asarray(TPR)
    fpr = np.asarray(FPR)
    youden_j = tpr - fpr
    best = np.argmax(youden_j)  # only the first occurrence is returned
    optimal_threshold = np.asarray(threshold)[best]
    point = [fpr[best], tpr[best]]
    return optimal_threshold, point
# Compute the ROC curve, its AUC, and the Youden-optimal operating point.
def ROC(label, y_prob):
    """Compute ROC statistics for one set of predicted scores.

    Args:
        label: True labels.
        y_prob: Predicted scores/probabilities for the positive class.

    Returns:
        ``(fpr, tpr, roc_auc, optimal_threshold, optimal_point)`` where the
        last two come from ``Find_Optimal_Cutoff``.
    """
    curve_fpr, curve_tpr, curve_thresholds = roc_curve(label, y_prob)
    area = auc(curve_fpr, curve_tpr)
    best_threshold, best_point = Find_Optimal_Cutoff(
        TPR=curve_tpr,
        FPR=curve_fpr,
        threshold=curve_thresholds,
    )
    return curve_fpr, curve_tpr, area, best_threshold, best_point
# Threshold the scores and compute accuracy / sensitivity / specificity.
def calculate_metric(label, y_prob, optimal_threshold):
    """Binarize scores at ``optimal_threshold`` and derive classification metrics.

    The 2x2 confusion matrix is printed as a side effect.

    Args:
        label: True labels; assumed to be binary 0/1, since predictions are 0/1.
        y_prob: Predicted scores/probabilities for the positive class.
        optimal_threshold: Scores ``>=`` this value are predicted as 1.

    Returns:
        ``(Accuracy, Sensitivity, Specificity)``. A rate whose denominator is
        zero (e.g. no positive samples for Sensitivity) is returned as NaN.
    """
    p = [1 if prob >= optimal_threshold else 0 for prob in y_prob]
    # labels=[0, 1] pins the matrix to 2x2 even if only one class appears;
    # without it confusion[1, 1] raises IndexError in that case.
    confusion = confusion_matrix(label, p, labels=[0, 1])
    print(confusion)
    # Layout is [[TN, FP], [FN, TP]].
    TN, FP, FN, TP = confusion.ravel()
    Accuracy = (TP + TN) / float(TP + TN + FP + FN)
    Sensitivity = TP / float(TP + FN) if (TP + FN) else float('nan')
    Specificity = TN / float(TN + FP) if (TN + FP) else float('nan')
    return Accuracy, Sensitivity, Specificity
# --- Multi-model comparison ---
# Each entry is (display name, unfitted estimator). Every estimator must
# support predict_proba; SVC needs probability=True for that.
models = [
    ('Logit', LogisticRegression(max_iter=5000)),
    ('KNN', KNeighborsClassifier()),
    ('SVM', SVC(probability=True)),
    ('GNB', GaussianNB()),
    ('DT', DecisionTreeClassifier(random_state=0)),
    ('RF', RandomForestClassifier(max_depth=2, random_state=0)),
]
# Train every model, evaluate it on the held-out split, and collect results.
results = []
roc_ = []  # one [fpr, tpr, auc, name] entry per model, used for plotting
for name, model in models:
    clf = model.fit(X_train, y_train)
    pred_proba = clf.predict_proba(X_test)
    # Column 1 = probability of clf.classes_[1] (the positive class for 0/1 labels).
    y_prob = pred_proba[:, 1]
    fpr, tpr, roc_auc, optimal_threshold, optimal_point = ROC(y_test, y_prob)
    # The original unpacked "Optimal_threshold" but then passed the undefined
    # lowercase "optimal_threshold", raising NameError; the name is now consistent.
    Accuracy, Sensitivity, Specificity = calculate_metric(y_test, y_prob, optimal_threshold)
    result = [optimal_threshold, Accuracy, Sensitivity, Specificity, roc_auc, name]
    results.append(result)
    roc_.append([fpr, tpr, roc_auc, name])
# Passing columns= up front also works when results is empty.
df_result = pd.DataFrame(
    results,
    columns=["Optimal_threshold", "Accuracy", "Sensitivity", "Specificity", "AUC_ROC", "Model_name"],
)
# --- Overlay the ROC curves of all models in one figure ---
color = ["darkorange", "navy", "red", "green", "yellow", "pink"]
# Only one figure: the original called a bare plt.figure() first, which left
# an extra empty figure open.
plt.figure(figsize=(10, 10))
lw = 2
# Loop over however many models were trained. The original hard-coded six
# entries (IndexError with fewer; later models silently dropped with more).
# Colors cycle if there are more models than colors.
for i, (curve_fpr, curve_tpr, curve_auc, model_name) in enumerate(roc_):
    plt.plot(
        curve_fpr,
        curve_tpr,
        color=color[i % len(color)],
        lw=lw,
        label=model_name + ' (AUC = %0.3f)' % curve_auc,
    )
# Diagonal reference line = a random classifier.
plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic Curve')
plt.legend(loc="lower right")
plt.savefig("roc_curve.png", dpi=300)
plt.show()
参考链接: