SVM的预测部分不一致的问题

问题引入

今天,在做基于统计的SVM文本分类到Bert的语言模型分类转换的时候,发现阈值从原来0.3飙到0.99
因此,考虑到svm的多分类和FC的softmax不同,做了些测试,看了一下SVM的处理

首先,先看结果,SVM内部矛盾了

#===这里演示一下由OVO,两条路,不矛盾===
#1.这里是predict函数结果 
[1]
#2.这是libsvm暴露的接口
[[0.35212698 0.26687814 0.15270139 0.11694623 0.11134726]]
#可以看出predict分类1,对应predict_proba的置信最大,没有矛盾

#===这里演示一下由OVO,两条路,矛盾===
#1.这里是predict函数结果 
[2]
#2.这是libsvm暴露的接口
[[0.19582204 0.14254546 0.16261094 0.23086777 0.26815379]]
#可以看出predict分类2,对应predict_proba的置信却不是最大,有矛盾

SVM两条路

原因是什么呢?是因为sklearn这里有两条路,策略不同,而数据量少,两者不严格对应,因为svm的置信度,也就是这里的概率计算复杂;

路一:sklearn的预测到底凭的啥

Created with Raphaël 2.2.0 libsvm包OVO sklearn包_ovr_decision_function接口的OVR(投票) sklearn包decision_function函数 sklearn包predict函数里面argmax predict函数预测结果

在最开始,先贴一段话
SVM的预测部分不一致的问题_第1张图片
libsvm包,为了简化多分类,虽然暴露的是OVR和OVO,底层都是OVO

首先,进入predict函数
SVM的预测部分不一致的问题_第2张图片
从592看到是取了一个最大的数据,什么的最大?
SVM的预测部分不一致的问题_第3张图片
从这个函数实现可以看出,是投票+置信修正。
所以:预测的思路是:OVO然后转化OVR然后取大
策略在sklearn内部实现

路二:predict_proba到底拿到的是啥

Created with Raphaël 2.2.0 libsvm包OVO libsvm包multiclass_probability函数 predict_proba函数概率结果

这里拿到的是libSVM直接的东西
我们先贴一个https://www.csie.ntu.edu.tw/~cjlin/papers/libsvm.pdf文档第八章的内容
SVM的预测部分不一致的问题_第4张图片
我们看到,其实lib内部计算了,公式44,可以认为,由原有分布(两两比较,不是分布,因为和不是1)进行二次规划得到带有约束条件的新的分布
这块的代码在:https://github.com/arnaudsj/libsvm/blob/master/svm.cpp的1828行

// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
static void multiclass_probability(int k, double **r, double *p)
{
//。。。
}

总的测试代码:

# -*- coding: utf-8 -*-
"""
======================
@author:YuanYihan
@time:2020/5/12:16:43
@email:[email protected]
@phone:18192015917
======================
"""
from sklearn.svm import SVC


def main():
    X = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
         [13, 14], [14, 15]]
    y = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]

    svc_clf = SVC(probability=True,
                  decision_function_shape='ovr',
                  break_ties=True, )
    svc_clf.fit(X, y)

    X_test = [[1, 2.5]]
    print("#===这里演示一下由OVO,两条路,不矛盾===")
    print("#1.这是libsvm的底层拿到的OVO")
    print(svc_clf._decision_function(X_test))  # 拿到底层OVO的结果
    #[[0.62205951 0.96867402 1.05203884 1.04267148 0.96880563 0.70934004 0.70606668 0.32225327 0.17647241 0.03762073]]
    print("#2.这是sklearn暴漏的接口,这里是OVO然后,经过OVR聚合,然后拿到最大值")
    print("#    2.1 这里是OVO")
    print(svc_clf._decision_function(X_test))  # 拿到底层OVO的结果
    print("#    2.2 这里是OVR")
    print(svc_clf.decision_function(X_test)) #OVR结果
    #[[ 4.26219102  3.21265452  1.80334849  0.77609939 -0.22082833]]
    print("#    2.3 这里是argmax ")
    print(svc_clf.predict(X_test))  # 底层是实现D:\Anaconda3\Lib\site-packages\sklearn\svm\_base.py(537行)
    #[1]
    print("#3.这是libsvm暴漏的接口")
    print(svc_clf.predict_proba(X_test))  # 直接libsvm的多分类结果
    #[[0.35212698 0.26687814 0.15270139 0.11694623 0.11134726]]
    #可以看出分类1,对应predict_proba的置信最大,没有矛盾

    X = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]
    y = [1, 2, 3, 4, 5, ]
    svc_clf.fit(X, y)

    X_test = [[1, 2.5]]
    print("#===这里演示一下由OVO,两条路,矛盾:predict_proba和predict不一致,原因是数据量少===")
    print("#1.这是libsvm的底层拿到的OVO")
    print(svc_clf._decision_function(X_test))  # 拿到底层OVO的结果
    #[[-0.46028768 -0.27179334  0.23631958  0.45192563  0.18849434  0.69660726 0.91221332  0.50811292  0.72371898  0.21560606]]
    print("#2.这是sklearn暴漏的接口,这里是OVO然后,经过OVR聚合,然后拿到最大值")
    print("#    2.1 这里是OVO")
    print(svc_clf._decision_function(X_test))  # 拿到底层OVO的结果
    print("#    2.2 这里是OVR")
    print(svc_clf.decision_function(X_test))  # OVR结果
    #[[ 1.98600169  4.2310086   3.18935299  0.8164502  -0.23242915]]
    print("#    2.3 这里是argmax ")
    print(svc_clf.predict(X_test))  # 底层是实现D:\Anaconda3\Lib\site-packages\sklearn\svm\_base.py(537行)
    #[2]
    print("#3.这是libsvm暴漏的接口")
    print(svc_clf.predict_proba(X_test))  # 直接libsvm的多分类结果
    #[[0.19582204 0.14254546 0.16261094 0.23086777 0.26815379]]
    #可以看出分类2,对应predict_proba的置信却不是最大,有矛盾,这是因为策略不同,而数据量少,两者不严格对应


if __name__ == '__main__':
    main()

参考:
https://blog.csdn.net/tjcwt2011/article/details/80936672

后记

可见,svm处理过程,和softmax的简单是有所不同的。

你可能感兴趣的:(other)