ROC Curve, PR Curve and AUC

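There are several standard criteria for judging how good a binary classifier is.

The true value takes one of two values, P and N, denoting positive and negative samples respectively.

The predicted value takes one of two values, Y and N, with the same meaning.

Start by tallying the following table, which everything else builds on: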

               actual P   actual N

predicted Y       TP         FP

predicted N       FN         TN

All the statistics below are built on these four counts.
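
As a minimal sketch of how the four counts in the table can be tallied, the snippet below assumes each sample has a score in [0, 1] and a label of 1 (positive) or 0 (negative), and that a sample is predicted positive when its score is at or above the threshold. The function name confusion_counts and the toy data are hypothetical, introduced only for illustration.

```python
def confusion_counts(scores, labels, threshold):
    """Return (tp, fp, fn, tn), predicting positive when score >= threshold."""
    tp = fp = fn = tn = 0
    for score, label in zip(scores, labels):
        predicted_positive = score >= threshold
        if label == 1:
            if predicted_positive:
                tp += 1  # actual positive, predicted positive
            else:
                fn += 1  # actual positive, predicted negative
        else:
            if predicted_positive:
                fp += 1  # actual negative, predicted positive
            else:
                tn += 1  # actual negative, predicted negative
    return tp, fp, fn, tn


# Hypothetical toy data: four samples, two positive and two negative.
scores = [0.9, 0.6, 0.4, 0.2]
labels = [1, 0, 1, 0]
print(confusion_counts(scores, labels, 0.5))
```

Sweeping the threshold from 0 to 1 and recomputing the counts each time is what produces the points of the curves discussed in this post.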

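ROC curve:

The horizontal axis is FPR and the vertical axis is TPR. FPR = FP/(FP+TN) and TPR = TP/(TP+FN): each rate is normalized by the number of actual negatives or actual positives respectively, not by the total number of samples.

For a classifier that does better than random guessing, the plotted curve lies above the line y = x. The integral of the curve (the area under it) is the AUC. The larger the AUC, the better the classifier performs.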

import sys
import math
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # sys.argv[1]: input file, one "<score> <label>" pair per line (label 1 = positive)
    # sys.argv[2]: step between successive thresholds
    width = float(sys.argv[2])
    with open(sys.argv[1]) as fin:
        data = [tuple(map(float, line.split())) for line in fin if line.strip()]
    i = 0.
    x = []  # TPR at each threshold
    y = []  # FPR at each threshold
    while i <= 1:
        tp = 0; fp = 0; fn = 0; tn = 0
        for a, b in data:
            if math.fabs(b - 1.0) < 1e-3:  # positive sample
                if a < i:
                    fn = fn + 1
                else:
                    tp = tp + 1
            else:  # negative sample
                if a < i:
                    tn = tn + 1
                else:
                    fp = fp + 1
        # Normalize by class size (the file is assumed to contain both classes).
        x.append(tp * 1. / (tp + fn))
        y.append(fp * 1. / (fp + tn))
        i = i + width
    plt.plot(y, x)
    plt.axis([0, 1.1, 0, 1.1])
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title('ROC curve')
    plt.show()
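
Since the AUC is described as the integral of the ROC curve, one simple way to get a number out of a list of ROC points is the trapezoidal rule. The following is only a sketch under that assumption; the function name auc_trapezoid and the sample points are hypothetical and not part of the original script.

```python
def auc_trapezoid(fpr, tpr):
    """Area under the polyline through the (fpr, tpr) points (trapezoidal rule)."""
    points = sorted(zip(fpr, tpr))  # order the points from left to right
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0  # one trapezoid per segment
    return area


# Hypothetical ROC points, including the end points (0, 0) and (1, 1).
fpr = [0.0, 0.0, 0.5, 1.0]
tpr = [0.0, 0.5, 1.0, 1.0]
print(auc_trapezoid(fpr, tpr))
```

The result depends on how finely the thresholds are sampled, and on whether the end points (0, 0) and (1, 1) are among the points, so a small threshold step gives a closer estimate.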

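PR curve:

The horizontal axis is precision and the vertical axis is recall. precision = TP/(TP+FP) and recall = TP/(TP+FN), i.e. precision and recall in the usual sense. The plotted curve looks somewhat like y = 1/x.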

import sys
import math
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # sys.argv[1]: input file, one "<score> <label>" pair per line (label 1 = positive)
    # sys.argv[2]: step between successive thresholds
    width = float(sys.argv[2])
    with open(sys.argv[1]) as fin:
        data = [tuple(map(float, line.split())) for line in fin if line.strip()]
    i = 0.
    x = []  # precision at each threshold
    y = []  # recall at each threshold
    while i <= 1:
        tp = 0; fp = 0; fn = 0; tn = 0
        for a, b in data:
            if math.fabs(b - 1.0) < 1e-3:  # positive sample
                if a < i:
                    fn = fn + 1
                else:
                    tp = tp + 1
            else:  # negative sample
                if a < i:
                    tn = tn + 1
                else:
                    fp = fp + 1
        # Precision is undefined when nothing is predicted positive,
        # so such thresholds are skipped instead of dividing by zero.
        if tp + fp > 0:
            x.append(tp * 1. / (tp + fp))
            y.append(tp * 1. / (tp + fn))
        i = i + width
    plt.plot(x, y)
    plt.axis([0, 1, 0, 1])
    plt.xlabel('precision')
    plt.ylabel('recall')
    plt.title('PR curve')
    plt.show()

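AUC:

This is the integral of the ROC curve above. In practice it is usually replaced by the probability that a positive sample's decision value is greater than a negative sample's decision value (the two agree when no scores are tied). Sorting the scores and sweeping over them once is enough to compute it.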

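import sys
import math

if __name__ == "__main__":
    # sys.argv[1]: input file, one "<score> <label>" pair per line (label 1 = positive)
    x = []  # decision values of positive samples
    y = []  # decision values of negative samples
    with open(sys.argv[1]) as fin:
        for line in fin:
            if not line.strip():
                continue
            a, b = map(float, line.split())
            if math.fabs(b - 1.0) < 1e-3:
                x.append(a)
            else:
                y.append(a)
    x.sort()
    y.sort()
    sx = len(x)
    sy = len(y)
    j = 0
    tot = 0
    for i in range(sx):
        # Advance j past every negative strictly below x[i];
        # j is then the number of negatives that x[i] beats.
        while j < sy and x[i] > y[j]:
            j = j + 1
        tot += j
    # Fraction of (positive, negative) pairs ranked correctly.
    # Tied pairs are counted as 0 here.
    tot = tot * 1. / (sx * sy)
    print(tot)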
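
The sort-and-sweep idea counts a tied positive/negative pair as a loss. A common alternative convention gives each tied pair half credit. The sketch below assumes that convention and checks the result against a brute-force comparison over all pairs; the names auc_pairwise and auc_brute_force and the toy scores are hypothetical.

```python
from bisect import bisect_left, bisect_right


def auc_pairwise(pos, neg):
    """P(positive score > negative score), giving tied pairs a credit of 0.5."""
    neg_sorted = sorted(neg)
    total = 0.0
    for s in pos:
        below = bisect_left(neg_sorted, s)          # negatives strictly below s
        tied = bisect_right(neg_sorted, s) - below  # negatives equal to s
        total += below + 0.5 * tied
    return total / (len(pos) * len(neg))


def auc_brute_force(pos, neg):
    """Same quantity by comparing every pair; O(len(pos) * len(neg))."""
    wins = sum((p > q) + 0.5 * (p == q) for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical decision values; 0.6 appears in both classes to create a tie.
pos = [0.9, 0.4, 0.6]
neg = [0.6, 0.2, 0.1]
print(auc_pairwise(pos, neg), auc_brute_force(pos, neg))
```

The binary-search version costs O((m + n) log n) for m positives and n negatives, which matters once the brute-force pair count becomes large.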


Example:

Two images first:

[Figure: binary map (left) and saliency map (right)]

Precision-recall:

[Figure not preserved]

ROC:

[Figure not preserved]

MATLAB:

 

precision(threshNo) = tp / (tp+fp);
recall(threshNo) = tp / (tp+fn);

tpp(threshNo) = tp / (tp+fn);   % true positive rate
fpp(threshNo) = fp / (fp+tn);   % false positive rate
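
For readers following along in Python, the same four ratios can be sketched as a small helper. Returning NaN when a denominator is zero is an assumption made here for illustration (the MATLAB fragment does not show how that case is handled), and the function name rates and the example counts are hypothetical.

```python
import math


def rates(tp, fp, fn, tn):
    """Precision, recall, TPR and FPR from the four confusion-matrix counts."""
    def ratio(num, den):
        # Assumption: report NaN instead of raising when the ratio is undefined.
        return num / den if den else float("nan")

    return {
        "precision": ratio(tp, tp + fp),
        "recall": ratio(tp, tp + fn),
        "tpr": ratio(tp, tp + fn),  # identical to recall
        "fpr": ratio(fp, fp + tn),
    }


# Hypothetical counts at a single threshold.
r = rates(tp=8, fp=2, fn=2, tn=88)
print(r)
```

Note that recall and TPR are the same quantity, so the PR curve and the ROC curve share one of their two coordinates.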


 

 

 

 

 
