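There are several standard criteria for judging how good a binary classification model is.
The ground truth takes two values, P and N, for positive and negative samples.
The prediction also takes two values, Y and N, with the same meaning.
First, tabulate the following counts as the basis: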
|             | actual P | actual N |
|-------------|----------|----------|
| predicted Y | TP       | FP       |
| predicted N | FN       | TN       |
All of the metrics below are built from these four counts.
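As a small sketch of how the four counts are tallied, assume each sample has a real-valued score and is predicted Y when its score is at least some threshold. The threshold rule, the 0/1 label encoding and the function name `confusion_counts` are assumptions made for illustration:

```python
def confusion_counts(scores, labels, threshold):
    """Return (tp, fp, fn, tn) when score >= threshold is predicted Y.

    labels: 1 for an actual positive (P), 0 for an actual negative (N).
    """
    tp = fp = fn = tn = 0
    for s, y in zip(scores, labels):
        predicted_pos = s >= threshold  # assumed decision rule
        if predicted_pos and y == 1:
            tp += 1
        elif predicted_pos:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn
```

For example, scores [0.9, 0.8, 0.6, 0.4, 0.3, 0.1] with labels [1, 0, 1, 1, 0, 0] at threshold 0.5 give (2, 1, 1, 2).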
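ROC curve:
The horizontal axis is FPR and the vertical axis is TPR, where FPR = FP/(FP+TN) is the fraction of actual negatives predicted positive, and TPR = TP/(TP+FN) is the fraction of actual positives predicted positive.
For a classifier that does better than random guessing, the curve lies above the line y = x. The area under the curve (its integral) is the AUC; the larger the AUC, the better the classifier performs.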
```python
import sys
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Usage: python roc.py <data_file> <step>
    # Each line of <data_file> holds "score label"; label 1 marks a positive sample.
    width = float(sys.argv[2])  # step between consecutive thresholds

    # Read the file once instead of reopening it for every threshold.
    samples = []
    with open(sys.argv[1]) as fin:
        for line in fin:
            if not line.strip():
                continue
            a, b = map(float, line.split())
            samples.append((a, abs(b - 1.0) < 1e-3))

    x = []  # TPR at each threshold
    y = []  # FPR at each threshold
    i = 0.0
    while i <= 1:
        tp = fp = fn = tn = 0
        for a, is_pos in samples:
            if is_pos:
                if a < i:
                    fn += 1
                else:
                    tp += 1
            else:
                if a < i:
                    tn += 1
                else:
                    fp += 1
        # Both classes must be present, otherwise a denominator is zero.
        x.append(tp / (tp + fn))  # TPR = TP / (number of positives)
        y.append(fp / (fp + tn))  # FPR = FP / (number of negatives)
        i += width

    plt.plot(y, x)
    plt.axis([0, 1.1, 0, 1.1])
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title('ROC curve')
    plt.show()
```
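A fixed-step threshold sweep like the one above can skip over thresholds that matter. As an alternative sketch, every distinct score can serve as a threshold, which yields the exact empirical ROC curve, and the trapezoidal rule then integrates it into the AUC. The function names, the 0/1 labels and the `>=` rule are assumptions for illustration, and both classes are assumed to be present:

```python
def roc_points(scores, labels):
    """Exact ROC points using every distinct score as a threshold.

    A sample is predicted positive when its score >= threshold.
    labels: 1 = positive, 0 = negative (both classes required).
    Returns (FPR, TPR) pairs running from (0, 0) to (1, 1).
    """
    p = sum(labels)
    n = len(labels) - p
    pairs = sorted(zip(scores, labels), reverse=True)  # highest score first
    points = [(0.0, 0.0)]
    tp = fp = 0
    k = 0
    while k < len(pairs):
        s = pairs[k][0]
        # Lowering the threshold to s turns every sample with score s positive at once.
        while k < len(pairs) and pairs[k][0] == s:
            if pairs[k][1] == 1:
                tp += 1
            else:
                fp += 1
            k += 1
        points.append((fp / n, tp / p))
    return points


def trapezoid_auc(points):
    """Area under a piecewise-linear curve given as (x, y) points sorted by x."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area
```

On scores [0.9, 0.8, 0.6, 0.4, 0.3, 0.1] with labels [1, 0, 1, 1, 0, 0], the area comes out to 7/9 (about 0.778).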
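PR curve:
Here the horizontal axis is precision and the vertical axis is recall, where precision = TP/(TP+FP) and recall = TP/(TP+FN). (Many references swap the axes and put recall on the horizontal axis.) As the threshold rises, precision tends to go up while recall goes down, so the curve roughly resembles y = 1/x.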
```python
import sys
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Usage: python pr.py <data_file> <step>
    # Each line of <data_file> holds "score label"; label 1 marks a positive sample.
    width = float(sys.argv[2])  # step between consecutive thresholds

    samples = []
    with open(sys.argv[1]) as fin:
        for line in fin:
            if not line.strip():
                continue
            a, b = map(float, line.split())
            samples.append((a, abs(b - 1.0) < 1e-3))

    x = []  # precision at each threshold
    y = []  # recall at each threshold
    i = 0.0
    while i <= 1:
        tp = fp = fn = 0
        for a, is_pos in samples:
            if a >= i:          # predicted positive
                if is_pos:
                    tp += 1
                else:
                    fp += 1
            elif is_pos:        # actual positive predicted negative
                fn += 1
        # With no predicted positives, precision is undefined (0/0): skip that threshold.
        if tp + fp > 0:
            x.append(tp / (tp + fp))
            y.append(tp / (tp + fn))
        i += width

    plt.plot(x, y)
    plt.axis([0, 1, 0, 1])
    plt.xlabel('precision')
    plt.ylabel('recall')
    plt.title('PR curve')
    plt.show()
```
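A sketch that avoids the undefined-precision case altogether: visit the distinct scores from high to low and use each as a threshold, so at least one sample is always predicted positive. The pairs are returned as (precision, recall) to match the axes used here; the function name, 0/1 labels and `>=` rule are illustrative assumptions, and at least one positive sample is assumed:

```python
def pr_points(scores, labels):
    """(precision, recall) at every distinct score used as a threshold.

    A sample is predicted positive when its score >= threshold.
    labels: 1 = positive, 0 = negative; at least one positive is assumed.
    """
    p = sum(labels)
    pairs = sorted(zip(scores, labels), reverse=True)  # highest score first
    points = []
    tp = fp = 0
    k = 0
    while k < len(pairs):
        s = pairs[k][0]
        # All samples tied at score s become positive together.
        while k < len(pairs) and pairs[k][0] == s:
            if pairs[k][1] == 1:
                tp += 1
            else:
                fp += 1
            k += 1
        # tp + fp >= 1 here, so precision is always defined.
        points.append((tp / (tp + fp), tp / p))
    return points
```

On scores [0.9, 0.8, 0.6, 0.4, 0.3, 0.1] with labels [1, 0, 1, 1, 0, 0], the first point is precision 1.0 at recall 1/3 and the last is precision 0.5 at recall 1.0.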
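AUC:
AUC is the integral of the ROC curve above. In practice it is usually computed another way: as the probability that a randomly chosen positive sample gets a higher decision score than a randomly chosen negative sample. This probability equals the area under the empirical ROC curve when tied scores count as 1/2; counting ties as 0, as the scan below does, slightly underestimates it when positive and negative scores coincide. Sort the scores, and a single linear scan is enough.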
```python
import sys

if __name__ == "__main__":
    # Usage: python auc.py <data_file>
    # Each line of <data_file> holds "score label"; label 1 marks a positive sample.
    x = []  # scores of positive samples
    y = []  # scores of negative samples
    with open(sys.argv[1]) as fin:
        for line in fin:
            if not line.strip():
                continue
            a, b = map(float, line.split())
            if abs(b - 1.0) < 1e-3:
                x.append(a)
            else:
                y.append(a)

    x.sort()
    y.sort()
    sx, sy = len(x), len(y)

    # Two-pointer scan: j counts negatives scoring strictly below x[i].
    # Because x is sorted, j never moves backwards, so the scan is O(sx + sy).
    j = 0
    tot = 0
    for i in range(sx):
        while j < sy and x[i] > y[j]:
            j += 1
        tot += j  # a positive/negative pair with equal scores contributes 0

    print(tot / (sx * sy))
```
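For a tie-aware version, one option is the rank formulation (the Mann-Whitney U statistic): give tied scores their average rank, sum the ranks of the positives, and normalize. This computes P(positive score > negative score) + 0.5 · P(tie). The sketch below is illustrative; the function name and 0/1 labels are assumptions, and both classes must be present:

```python
def auc_rank(scores, labels):
    """AUC = P(s_pos > s_neg) + 0.5 * P(s_pos == s_neg), via average ranks.

    labels: 1 = positive, 0 = negative; both classes must be present.
    Runs in O(m log m) for m samples.
    """
    order = sorted(range(len(scores)), key=lambda k: scores[k])
    ranks = [0.0] * len(scores)
    k = 0
    while k < len(order):
        # order[k:end] is a block of tied scores; each gets the mean of ranks k+1..end.
        end = k
        while end < len(order) and scores[order[end]] == scores[order[k]]:
            end += 1
        avg_rank = (k + 1 + end) / 2
        for t in range(k, end):
            ranks[order[t]] = avg_rank
        k = end
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    rank_sum_pos = sum(r for r, y in zip(ranks, labels) if y == 1)
    u = rank_sum_pos - n_pos * (n_pos + 1) / 2  # number of (pos, neg) wins, ties = 1/2
    return u / (n_pos * n_neg)
```

Without ties it agrees with the strict scan (7/9 on scores [0.9, 0.8, 0.6, 0.4, 0.3, 0.1] with labels [1, 0, 1, 1, 0, 0]). For one positive and one negative with equal scores it returns 0.5, where the strict scan gives 0.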
Example:
[Figures: a binary map and a saliency map, followed by the resulting precision-recall curve and ROC curve.]
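MATLAB, computed once per threshold index threshNo:

```matlab
precision(threshNo) = tp / (tp + fp);
recall(threshNo)    = tp / (tp + fn);
tpp(threshNo)       = tp / (tp + fn);  % TPR, the same quantity as recall
fpp(threshNo)       = fp / (fp + tn);  % FPR
```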
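For readers working in Python, the MATLAB lines above can be sketched as a loop over thresholds that compares a saliency map against its binary ground-truth map. The MATLAB fragment does not show how tp, fp, fn and tn are obtained, so the input layout (nested lists, a 0/1 mask), the `>=` thresholding rule, returning `None` for a 0/0 ratio, and the function name `saliency_curves` are all assumptions:

```python
def saliency_curves(saliency, mask, thresholds):
    """Per-threshold precision, recall, TPR (tpp) and FPR (fpp).

    saliency: 2-D list of saliency values; mask: 2-D list of 0/1 ground truth.
    A pixel is predicted salient when its value >= threshold (assumed rule).
    A ratio with a zero denominator is reported as None.
    """
    pixels = [(s, g) for srow, grow in zip(saliency, mask) for s, g in zip(srow, grow)]
    precision, recall, tpp, fpp = [], [], [], []
    for t in thresholds:
        tp = sum(1 for s, g in pixels if s >= t and g == 1)
        fp = sum(1 for s, g in pixels if s >= t and g == 0)
        fn = sum(1 for s, g in pixels if s < t and g == 1)
        tn = sum(1 for s, g in pixels if s < t and g == 0)
        precision.append(tp / (tp + fp) if tp + fp else None)
        recall.append(tp / (tp + fn) if tp + fn else None)
        tpp.append(tp / (tp + fn) if tp + fn else None)  # identical to recall
        fpp.append(fp / (fp + tn) if fp + tn else None)
    return precision, recall, tpp, fpp
```

For a 2×2 map [[200, 50], [120, 10]] with mask [[1, 0], [1, 0]] and thresholds [0, 100, 255], precision is [0.5, 1.0, None] and FPR is [1.0, 0.0, 0.0].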