利用scikitlearn画ROC曲线

一个完整的数据挖掘模型,最后都要进行模型评估,对于二分类来说,AUCROC这两个指标用到最多,所以 利用sklearn里面相应的函数进行模块搭建。

具体实现的代码可以参照下面博友的代码,评估svm的分类指标。注意里面的一些细节需要注意,一个是调用roc_curve 方法时,指明目标标签,否则会报错。具体是这个参数的设置pos_label ,以前在unionbigdata实习时学到的。

重点是以下的代码需要根据实际改写:

    mean_tpr = 0.0  
    mean_fpr = np.linspace(0, 1, 100)  
    all_tpr = []
    
    y_target = np.r_[train_y,test_y]
    cv = StratifiedKFold(y_target, n_folds=6)

        #画ROC曲线和计算AUC
        fpr, tpr, thresholds = roc_curve(test_y, predict,pos_label = 2)##指定正例标签,pos_label = ###########在数之联的时候学到的,要制定正例
        
        mean_tpr += interp(mean_fpr, fpr, tpr)          #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数  
        mean_tpr[0] = 0.0                               #初始处为0  
        roc_auc = auc(fpr, tpr)  
        #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来  
        plt.plot(fpr, tpr, lw=1, label='ROC  %s (area = %0.3f)' % (classifier, roc_auc)) 


然后是博友的参考代码:

[python]  view plain  copy
 
  1. # -*- coding: utf-8 -*-  
  2. """ 
  3. Created on Sun Apr 19 08:57:13 2015 
  4.  
  5. @author: shifeng 
  6. """  
  7. print(__doc__)  
  8.   
  9. import numpy as np  
  10. from scipy import interp  
  11. import matplotlib.pyplot as plt  
  12.   
  13. from sklearn import svm, datasets  
  14. from sklearn.metrics import roc_curve, auc  
  15. from sklearn.cross_validation import StratifiedKFold  
  16.   
  17. ###############################################################################  
  18. # Data IO and generation,导入iris数据,做数据准备  
  19.   
  20. # import some data to play with  
  21. iris = datasets.load_iris()  
  22. X = iris.data  
  23. y = iris.target  
  24. X, y = X[y != 2], y[y != 2]#去掉了label为2,label只能二分,才可以。  
  25. n_samples, n_features = X.shape  
  26.   
  27. # Add noisy features  
  28. random_state = np.random.RandomState(0)  
  29. X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]  
  30.   
  31. ###############################################################################  
  32. # Classification and ROC analysis  
  33. #分类,做ROC分析  
  34.   
  35. # Run classifier with cross-validation and plot ROC curves  
  36. #使用6折交叉验证,并且画ROC曲线  
  37. cv = StratifiedKFold(y, n_folds=6)  
  38. classifier = svm.SVC(kernel='linear', probability=True,  
  39.                      random_state=random_state)#注意这里,probability=True,需要,不然预测的时候会出现异常。另外rbf核效果更好些。  
  40.   
  41. mean_tpr = 0.0  
  42. mean_fpr = np.linspace(01100)  
  43. all_tpr = []  
  44.   
  45. for i, (train, test) in enumerate(cv):  
  46.     #通过训练数据,使用svm线性核建立模型,并对测试集进行测试,求出预测得分  
  47.     probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])  
  48. #    print set(y[train])                     #set([0,1]) 即label有两个类别  
  49. #    print len(X[train]),len(X[test])        #训练集有84个,测试集有16个  
  50. #    print "++",probas_                      #predict_proba()函数输出的是测试集在lael各类别上的置信度,  
  51. #    #在哪个类别上的置信度高,则分为哪类  
  52.     # Compute ROC curve and area the curve  
  53.     #通过roc_curve()函数,求出fpr和tpr,以及阈值  
  54.     fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])  
  55.     mean_tpr += interp(mean_fpr, fpr, tpr)          #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数  
  56.     mean_tpr[0] = 0.0                               #初始处为0  
  57.     roc_auc = auc(fpr, tpr)  
  58.     #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来  
  59.     plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc))  
  60.   
  61. #画对角线  
  62. plt.plot([01], [01], '--', color=(0.60.60.6), label='Luck')  
  63.   
  64. mean_tpr /= len(cv)                     #在mean_fpr100个点,每个点处插值插值多次取平均  
  65. mean_tpr[-1] = 1.0                      #坐标最后一个点为(1,1)  
  66. mean_auc = auc(mean_fpr, mean_tpr)      #计算平均AUC值  
  67. #画平均ROC曲线  
  68. #print mean_fpr,len(mean_fpr)  
  69. #print mean_tpr  
  70. plt.plot(mean_fpr, mean_tpr, 'k--',  
  71.          label='Mean ROC (area = %0.2f)' % mean_auc, lw=2)  
  72.   
  73. plt.xlim([-0.051.05])  
  74. plt.ylim([-0.051.05])  
  75. plt.xlabel('False Positive Rate')  
  76. plt.ylabel('True Positive Rate')  
  77. plt.title('Receiver operating characteristic example')  
  78. plt.legend(loc="lower right")  
  79. plt.show()  

你可能感兴趣的:(Python语法相关,数据挖掘)