ROC曲线和PR曲线是机器学习中两个常见的评估指标(对于二分类器而言),做个笔记…
在二分类问题中,分类器将一个实例的分类标记为是或否,这可以用一个混淆矩阵来表示。混淆矩阵有四个分类,如下表:
TP(True Positive):指正确分类的正样本数,即预测为正样本,实际也是正样本。
FP(False Positive):指被错误地标记为正样本的负样本数,即实际为负样本而被预测为正样本,所以是False。
TN(True Negative):指正确分类的负样本数,即预测为负样本,实际也是负样本。
FN(False Negative):指被错误地标记为负样本的正样本数,即实际为正样本而被预测为负样本,所以是False。
Precision=TP/(TP+FP) —— (正确分类的正样本)/ (预测为正的总样本)
Recall=TP/(TP+FN) —— (正确分类的正样本)/ (实际为正的总样本)
TPR=TP/(TP+FN)=Recall # 真正例率
FPR=FP/(TN+FP) #假正例率
1、 ROC曲线由于兼顾正例与负例,所以适用于评估分类器的整体性能,相比而言PR曲线完全聚焦于正例。
2、如果有多份数据且存在不同的类别分布,比如信用卡欺诈问题中每个月正例和负例的比例可能都不相同,这时候如果只想单纯地比较分类器的性能且剔除类别分布改变的影响,则ROC曲线比较适合,因为类别分布改变可能使得PR曲线时好时坏,这时难以进行模型比较;反之,如果想测试不同类别分布对分类器性能的影响,则PR曲线比较适合。
3、如果想要评估在相同的类别分布下正例的预测情况,则宜选PR曲线。
4、类别不平衡问题中,ROC曲线通常会给出一个乐观的效果估计,所以大部分时候还是PR曲线更好。
5、最后,可以根据具体的应用,在曲线上找到最优的点,得到相对应的precision,recall,f1 score等指标,去调整模型的阈值,从而得到一个符合具体应用的模型。
matplotlib>=2.0.2
numpy>=1.13.0
opencv-python>=3.3.1+contrib
tqdm>=4.19.4
# -*- coding:utf-8 -*-
import glob
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
class CollectData:
    """
    Collect pixel-wise confusion counts for pairs of image files and derive
    ROC / PR curve data and IoU from them.

    A ground-truth pixel counts as positive when the first channel of the
    image is > 0; a predicted pixel counts as positive when the first channel
    of the probability image is >= the threshold under evaluation.
    """

    # Thresholds 0..254 are evaluated (threshold 255 is not), which keeps the
    # length of the arrays returned by statistics() at 255.
    _NUM_THRESHOLDS = 255
    # Number of histogram bins; relies on cv2.imread (default flags) yielding
    # 8-bit data, i.e. intensities 0..255.
    _NUM_LEVELS = 256

    def __init__(self):
        # Empty lists so that IoU()/debug() are usable before reload().
        self.groundtruth = []
        self.probgraph = []
        self.TP = []
        self.FP = []
        self.FN = []
        self.TN = []

    def reload(self, groundtruth, probgraph):
        """
        replace the image lists and reset the accumulated counts
        :param groundtruth: list,groundtruth image list
        :param probgraph: list,prob image list
        :return: None
        """
        self.groundtruth = groundtruth
        self.probgraph = probgraph
        self.TP = []
        self.FP = []
        self.FN = []
        self.TN = []

    @staticmethod
    def _read_pair(gt_path, prob_path):
        """
        read one groundtruth/prob pair from disk
        :param gt_path: string,path of the groundtruth image
        :param prob_path: string,path of the prob image
        :return: (boolean groundtruth mask, first channel of the prob image)
        :raises IOError: when either file cannot be read
        :raises ValueError: when the two images differ in size
        """
        gt_img = cv2.imread(gt_path)
        prob_img = cv2.imread(prob_path)
        # cv2.imread signals failure by returning None instead of raising.
        if gt_img is None:
            raise IOError("cannot read groundtruth image: %s" % gt_path)
        if prob_img is None:
            raise IOError("cannot read prob image: %s" % prob_path)
        # Only the first channel is used (images are presumably grayscale,
        # so all channels would be equal -- confirm for colour inputs).
        gt_mask = gt_img[:, :, 0] > 0
        prob = prob_img[:, :, 0]
        if gt_mask.shape != prob.shape:
            raise ValueError("size mismatch between %s %s and %s %s"
                             % (gt_path, gt_mask.shape, prob_path, prob.shape))
        return gt_mask, prob

    @staticmethod
    def _ratio(numerator, denominator, undefined):
        """
        element-wise numerator/denominator; entries whose denominator is 0
        are set to `undefined` instead of producing a 0/0 warning
        """
        out = np.full(numerator.shape, undefined, dtype='float64')
        np.divide(numerator, denominator, out=out, where=denominator > 0)
        return out

    def statistics(self):
        """
        calculate FPR TPR Precision Recall IoU

        Also stores the per-threshold counts in self.TP/FP/FN/TN
        (float64 arrays, one entry per threshold 0..254).
        :return: (FPR,TPR,AUC),(Precision,Recall,MAP),IoU
        """
        assert len(self.groundtruth) == len(self.probgraph), \
            "groundtruth and probgraph must list the same number of images"

        # Every image is read once.  Instead of re-thresholding each image
        # for every threshold, histogram the prob values separately for
        # positive and negative groundtruth pixels; the counts for all
        # thresholds then follow from cumulative sums.
        pos_hist = np.zeros(self._NUM_LEVELS, dtype=np.int64)
        neg_hist = np.zeros(self._NUM_LEVELS, dtype=np.int64)
        pairs = list(zip(self.groundtruth, self.probgraph))
        for gt_path, prob_path in tqdm(pairs):
            gt_mask, prob = self._read_pair(gt_path, prob_path)
            pos_hist += np.bincount(prob[gt_mask], minlength=self._NUM_LEVELS)
            neg_hist += np.bincount(prob[~gt_mask], minlength=self._NUM_LEVELS)

        # Reverse cumulative sum: entry t = number of pixels with value >= t,
        # i.e. the pixels predicted positive at threshold t.
        count = self._NUM_THRESHOLDS
        self.TP = np.cumsum(pos_hist[::-1])[::-1][:count].astype('float64')
        self.FP = np.cumsum(neg_hist[::-1])[::-1][:count].astype('float64')
        self.FN = float(pos_hist.sum()) - self.TP
        self.TN = float(neg_hist.sum()) - self.FP

        nan = float('nan')
        # FPR/TPR stay undefined (NaN) when the data has no negative /
        # no positive groundtruth pixels at all.
        FPR = self._ratio(self.FP, self.FP + self.TN, nan)
        TPR = self._ratio(self.TP, self.TP + self.FN, nan)
        # Both rates fall as the threshold rises, hence (previous - next)
        # for the trapezoid widths.
        AUC = np.round(np.sum((TPR[1:] + TPR[:-1]) * (FPR[:-1] - FPR[1:])) / 2., 4)

        # With no positive predictions precision is 0/0; use 1.0 there so a
        # high threshold that selects nothing does not turn MAP into NaN.
        Precision = self._ratio(self.TP, self.TP + self.FP, 1.0)
        Recall = TPR.copy()
        MAP = np.round(np.sum((Precision[1:] + Precision[:-1]) * (Recall[:-1] - Recall[1:])) / 2., 4)

        iou = self.IoU()
        return (FPR, TPR, AUC), (Precision, Recall, MAP), iou

    def IoU(self, threshold=128):
        """
        to calculate IoU over all image pairs
        :param threshold: numerical,a threshold for gray image to binary image
        :return: IoU rounded to 4 decimals; NaN when the union is empty
        """
        intersection = 0
        union = 0
        for gt_path, prob_path in zip(self.groundtruth, self.probgraph):
            gt_mask, prob = self._read_pair(gt_path, prob_path)
            pred_mask = prob >= threshold
            intersection += np.count_nonzero(gt_mask & pred_mask)
            union += np.count_nonzero(gt_mask | pred_mask)
        if union == 0:
            # No images, or neither groundtruth nor prediction has a
            # positive pixel: the ratio is undefined.
            return float('nan')
        return np.round(float(intersection) / union, 4)

    def debug(self):
        """
        show debug info: list sizes and the groundtruth/prob pairing
        :return: None
        """
        print("Now enter debug mode....\nPlease check the info bellow:")
        print("total groundtruth: %d total probgraph: %d\n" % (len(self.groundtruth), len(self.probgraph)))
        # zip stops at the shorter list, so lists of unequal length can
        # still be inspected here.
        for gt_path, prob_path in zip(self.groundtruth, self.probgraph):
            print(gt_path, prob_path)
        print("Please confirm the groundtruth and probgraph name is opposite")
class DrawCurve:
    """
    draw ROC/PR curve

    Usage: newly() to reset, reload() once per model, then roc() or pr().
    The image is written to savepath + '<dataName>_ROC.png' (or '_PR.png');
    savepath is concatenated as-is, so it needs its own trailing separator.
    """
    def __init__(self, savepath):
        self.savepath = savepath
        self.colorbar = ['red', 'green', 'blue', 'black']
        # Plot format strings, cycled per model.  '-*' (solid line with star
        # markers) is valid as a format string but not as a `linestyle`
        # keyword, so these are passed to plt.plot as the format argument.
        self.linestyle = ['-', '-.', '--', ':', '-*']
        self.dataName = None
        # Create the data lists right away so reload() works even when
        # newly() was never called; modelnum=None means "draw every curve".
        self.newly(None)

    def reload(self, xdata, ydata, auc, dataName, modelName):
        """
        this function is to update data for Function roc/pr to draw
        :param xdata: list,x-coord of roc(pr)
        :param ydata: list,y-coord of roc(pr)
        :param auc: numerical,area under curve
        :param dataName: string,name of dataset
        :param modelName: string,name of test model
        :return: None
        """
        self.xdata.append(xdata)
        self.ydata.append(ydata)
        self.modelName.append(modelName)
        self.auc.append(auc)
        self.dataName = dataName

    def newly(self, modelnum):
        """
        renew all the data
        :param modelnum: numerical,number of models to draw
                         (None: draw every reloaded curve)
        :return: None
        """
        self.modelnum = modelnum
        self.xdata = []
        self.ydata = []
        self.modelName = []
        self.auc = []

    def _draw(self, fignum, kind, xlabel, ylabel, metric):
        """
        shared plotting routine of roc() and pr()
        :param fignum: numerical,matplotlib figure number to draw on
        :param kind: string,'ROC' or 'PR',used in the title and the file name
        :param xlabel: string,label of the x axis
        :param ylabel: string,label of the y axis
        :param metric: string,name shown in front of the area value in the legend
        :return: None
        """
        plt.figure(fignum)
        # The numbered figure is reused between calls; clear it so curves of
        # an earlier call do not show up in this image.
        plt.clf()
        plt.title('%s Curve of %s' % (kind, self.dataName), fontsize=15)
        plt.xlabel(xlabel, fontsize=15)
        plt.ylabel(ylabel, fontsize=15)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        curves = list(zip(self.xdata, self.ydata, self.modelName, self.auc))
        # Slicing draws at most modelnum curves and never indexes past the
        # curves that were actually reloaded.
        for i, (xs, ys, name, area) in enumerate(curves[:self.modelnum]):
            plt.plot(xs, ys, self.linestyle[i % len(self.linestyle)],
                     color=self.colorbar[i % len(self.colorbar)],
                     linewidth=2.0,
                     label=name + ',%s:' % metric + str(area))
        plt.legend()
        plt.savefig(self.savepath + '%s_%s.png' % (self.dataName, kind), dpi=800)

    def roc(self):
        """
        draw ROC curve,save the curve graph to savepath
        (x: False Positive Rate, y: True Positive Rate)
        :return: None
        """
        self._draw(1, 'ROC', "False Positive Rate", "True Positive Rate", 'AUC')

    def pr(self):
        """
        draw PR curve,save the curve to savepath
        (x: Recall, y: Precision)
        :return: None
        """
        self._draw(2, 'PR', "Recall", "Precision", 'MAP')
def fileList(imgpath,filetype):
    """
    list the files matching a glob pattern, in sorted order
    :param imgpath: string,prefix of the pattern (concatenated as-is, so a
                    directory needs its trailing separator, e.g. './gt/')
    :param filetype: string,glob pattern for the file name, e.g. '*.png'
    :return: list,sorted matching paths
    """
    # glob.glob returns matches in arbitrary order, but callers pair the
    # groundtruth and prob lists by index, so the order must be deterministic.
    return sorted(glob.glob(imgpath + filetype))
def drawCurve(gtlist,problist,modelName,dataset,savepath='./'):
    """
    draw ROC PR curve,calculate AUC MAP IoU
    :param gtlist: list,groundtruth list
    :param problist: list,list of probgraph list
    :param modelName: list,name of test,model
    :param dataset: string,name of dataset
    :param savepath: string,path to save curve
    :return: list,one dict per model (in input order) with the keys
             'model','AUC','MAP','IoU'
    """
    assert len(problist) == len(modelName), \
        "problist and modelName must have the same length"
    process = CollectData()
    painter_roc = DrawCurve(savepath)
    painter_pr = DrawCurve(savepath)
    modelNum = len(problist)
    painter_roc.newly(modelNum)
    painter_pr.newly(modelNum)
    results = []
    # calculate param
    for probs, name in zip(problist, modelName):
        process.reload(gtlist, probs)
        (FPR, TPR, AUC), (Precision, Recall, MAP), IoU = process.statistics()
        painter_roc.reload(FPR, TPR, AUC, dataset, name)
        # pr() labels the x axis "Recall" and the y axis "Precision",
        # so Recall has to be passed as xdata.
        painter_pr.reload(Recall, Precision, MAP, dataset, name)
        results.append({'model': name, 'AUC': AUC, 'MAP': MAP, 'IoU': IoU})
    # draw curve and save
    painter_roc.roc()
    painter_pr.pr()
    return results
if __name__ == "__main__":
    # Example run: compare two models on one dataset.  Groundtruth masks are
    # taken from ./gt/ and each model's prob images from its own folder.
    dataset_tag = 'kaggle'
    model_labels = ["fcn", "unet"]
    truth_files = fileList('./gt/', '*.png')
    prediction_sets = [fileList(folder, '*.png') for folder in ('./pre1/', './pre2/')]
    drawCurve(truth_files, prediction_sets, model_labels, dataset_tag)
    print('--------------------------------------')