毕业设计做到最后光有准确率不行,还非得要个图表,不然没有说服力,SVM又没有学习曲线,那就画个ROC曲线。直接写了一套函数,也方便比较不同模型哪个更好一些。
ROC曲线横坐标为假阳性率/假正例率/误报率FPR(False Positive Rate),纵坐标为真阳性率/真正例率/正样本率TPR(True Positive Rate)。
模型在对样本进行预测之后,一般来说分类的正确率是无法达到100%。由此我们就能得到:
一个样本通过“实际是哪一类”和“被模型分类到哪一类”这两个属性可以确定其属于TP、FN、FP还是TN。而所谓TPR(True Positive Rate)也就是TP(True Positive)在P(Positive)中所占的比例(Rate)。注意P指的是实际正样本的个数,即P = TP + FN。真正例率:在预测为“真的”的一群样本中,其中预测正确的概率。
FPR = FP/N,这里的N(Negative)指的是实际负样本的个数。误报率这个名字可能好理解一些,原本是“假的”却被判断成“真的”的概率。
刚接触这些很容易混淆,假设实际样本的“正负”表示的是警察局接到的报警电话是真的还是假的,预测的样本正负指代接线员判断这个报警电话是真是假,并由此来决定警察是否出警。那么TPR就是警察正确判断真的报警电话并出警的次数/总计真的报警电话的数量。也就是说TPR的值越高,说明警察正确判断越多,接线员的判断越准确,工作更有效。而FPR则是接线员判断错误,警察白费力气出警的次数/假的报警电话的数量,也就是说FPR的值越大,警察做的无用功越多,接线员判断的越不准确。
也就是说,我们追求的模型应该具有低FPR和高TPR,也就是接近左上角的位置
因为ROC曲线的作用是来比较不同模型的训练效果,越接近左上角,曲线积分面积越大,分类的效果也就越好。
下面以五个样本为例说明如何将n个样本映射成n个ROC曲线上的点:
假设有5个样本,它们的实际类标签分别为[-1,-1,1,1,1],它们的决策值分别为[0.3,0.4,0.5,0.6,0.7]。
第一次选择0.3作为阈值,小于等于0.3的分类为-1,大于0.3的分类为1。分类后的类别标签为[-1,1,1,1,1],(得到预测后的类别标签与实际的类标签就能得到FP和TP了,也就意味着可以根据分类后的类别标签和实际的类别标签来计算TPR和FPR了)
判断正确的正样本数量为3,实际的正样本数量为3。判断错误的负样本数量为1,实际的负样本数量为2
FPR=1/2=0.5,TPR=3/3=1
第二次选择0.4作为阈值,分类后的标签为[-1,-1,1,1,1],判断正确的正样本数量为3,判断错误的负样本数量为0。
FPR=0,TPR=3/3=1。
第三次选择0.5为阈值,分类后标签为[-1,-1,-1,1,1] ,判断正确的正样本数量为2,判断错误的负样本数量为0。
FPR=0,TPR=2/3=0.67。
第四次选择0.5为阈值,分类后标签为[-1,-1,-1,-1,1] ,判断正确的正样本数量为1,判断错误的负样本数量为0。
FPR=0,TPR=1/3=0.33。
第五次选择0.5为阈值,分类后标签为[-1,-1,-1,-1,-1] ,判断正确的正样本数量为0,判断错误的负样本数量为0。
FPR=0,TPR=0。
于是由5个样本获得了5个ROC的坐标(0.5,1),(0,1),(0,0.67),(0,0.33),(0,0)。由这5个坐标绘制出该模型的ROC曲线。
y=x即TPR = FPR,如果ROC曲线y=x附近,说明这个模型分类的效果非常差(五五开)。通常会使用AUC值来辅助ROC曲线,AUC计算的是曲线下方的面积大小,也就意味着AUC值越大,ROC曲线越接近左上角(0,1),分类的效果也就越好,因为不同模型对同一组数据通常有不同的决策值,也就有不同的ROC曲线,可以用于比较两个不同模型训练效果。不过我都没绘制曲线,AUC也就没算。
绘制ROC曲线时,使用的是一组测试数据,这些数据实际的类别我们是知道的,通过模型对这些数据进行预测之后,我们能得到模型对每个样本的一个决策值,在libsvm中:
p_label, p_acc, p_val = svm.svm_predict(...)
p_label是模型对样本预测的分类,p_acc是准确率,p_val则是决策值,在绘制ROC曲线时我们要用到p_label和p_val,在模型对数据集预测完之后将其写入到ROC.txt中去,格式像下面这样,每一行对应一个样本,两个数值分别是样本的实际分类和决策值。
然后通过这两个数据,计算出ROC曲线上的坐标。将ROC.txt作为参数file_path传入。通过前面的概念可以知道,ROC曲线上的每一个坐标都是一对(FPR,TPR)。我们每次将一个样本的决策值作为分类的阈值,一旦有了阈值,也就能对所有的样本进行分类,而我们也预先知道每个样本的实际分类。于是就能得到TPR和FPR。将所有数据的决策值都轮流用作一次阈值,n个样本就能得到n个ROC曲线上的坐标。
#用于ROC坐标的排序
def takeFirst(elem):
return elem[0]
def createROC(file_path,save_path):
#读取txt文件,将其中的实际label和决策值p_val读取出来
label = []
p_val = []
index = 0
#实际正样本数量
count_P = 0
#实际负样本数量
count_N = 0
with open(file_path, 'r') as f:
for line in f.readlines():
# 去掉末尾换行符
line = line.strip()
# 每一行为一列的最大值与最小值
#l - label和v - val
l,v = line.split(" ",1)
label.append(float(l))
p_val.append(float(v))
if (label[index]>0):
count_P += 1
else:
count_N += 1
index = index + 1
#读取完毕,获得了index个样本的标签和决策值
#实际正样本数量:count_P,实际负样本数量:count_N
#存储TPR和FPR
pos = []
#然后依次将这些决策值作为阈值
for value in p_val:
#当前的阈值为value
#要根据当前阈值,获得分类正确的正样本数量 TP
#分类错误的负样本数量 FP
TP = 0
FP = 0
index = 0
#每个决策值依次作为分类的阈值
for i in p_val:
if(i<=value):
label_now = -1.0
else:
label_now = 1.0
#如果是分类正确的正样本
if(label_now>0 and label[index]>0):
TP += 1
#如果是分类错误的负样本
if(label_now>0 and label[index]<0):
FP += 1
#计数增加
index += 1
#计算FPR和TPR
FPR = FP/count_N
TPR = TP/count_P
pos.append((FPR,TPR))
#按横坐标进行排序 (默认升序),为了之后的绘图(如果使用Matlab的话不需要这一行)
pos.sort(key=takeFirst)
#写入文件保存起来
for x,y in pos:
pos_str = str(x) + " " + str(y) + "\n"
#保存
get_f.writeindata(save_path, pos_str)
return
有了点的坐标,就可以使用matlab拟合出曲线了,但我当时做到这里偷懒了,为了能快速出结果,同时样本的数量有3000个足够多,就直接用折线来近似曲线了。
注意:绘制折线的时候要将坐标点按横坐标从小到大排序然后依次绘制。
def drawROC(path):
pos_x = []
pos_y = []
with open(path, 'r') as f:
for line in f.readlines():
# 去掉末尾换行符
line = line.strip()
# 每一行为一列的最大值与最小值
roc_x,roc_y = line.split(" ",1)
#print(roc_x+","+roc_y)
pos_x.append(float(roc_x))
pos_y.append(float(roc_y))
# 绘制ROC曲线
# 设置画图窗口大小
plt.figure(figsize=(5, 5))
# 添加标题
plt.title("ROC Curve", fontsize='15')
#
plt.plot([0,1], [1,1], color='black', ls='--', lw=1, label='top')
plt.plot([0, 1], [0, 1], color='blue', ls='--', lw=1, label='TPR=FPR')
#绘制ROC
plt.plot(pos_x,pos_y,color='red',ls='-',lw=2,label='roc line')
plt.xlim((-0.01, 1.01)) # 设置x轴最大最小值 偏移了一点是为了让曲线展示得更清晰
plt.ylim((0, 1.05)) # 设置y轴最大最小值
plt.xlabel('FPR') # 添加x轴图标
plt.ylabel('TPR') # 添加y轴图标
plt.legend() # 添加图例
#将图片保存起来
plt.savefig('E://data/424/ROC.png')
plt.show()
return