SVM之软间隔最大化

跟我一起机器学习系列文章将首发于公众号:月来客栈,欢迎文末扫码关注!

在前面几篇文章中,笔者分别介绍了什么是支持向量机以及如何通过sklearn来完成一个简单的SVM建模;接着还介绍了什么是线性不可分与核函数。在接下来的这篇文章中,笔者将继续介绍SVM中的软间隔与sklearn相关SVM模型的实现。

1 什么是软间隔

我们之前谈到过两种情况下的分类:一种是直接线性可分的;另外一种是通过 ϕ ( x ) \phi(x) ϕ(x)映射到高维空间之后“线性可分”的。为什么后面这个“线性可分”要加上引号呢?这是因为在 上一篇文章 中有一件事没有和大家交代:虽然通过映射到高维空间的方式能够很大程度上使得原先线性不可分的数据集线性可分,但是我们并不能够一定保证它就是线性可分的,可能这个高维空间依旧线性不可分得换一个(事实上你还是不知道换哪一个更好,所以此时就要折中选择),或者保守的说即使线性可分了,但也可能会有过拟合现象。这是因为超平面对于异常点(outlier)过于敏感。如下图:

SVM之软间隔最大化_第1张图片

如上图(a)所示,实线为该数据集下的最优决策面。但是如果在测试集中出现一个异常点,如图(b)所示,那它将导致分类直线发生剧烈的摆动,虽然最终也达到了将数据集分开的效果,但这显然不是我们希望的。我们将图(a)和(b)中的情况称之为硬间隔(hard margin),即不允许任何样本出现错分的情况,哪怕导致过拟合。而我们所期望的应该是图©中的这种情况:容许少量样本被错分,从而得到一个次优解,而这个容忍的程度则通过目标函数来调节。或者再极端一点就是,根本找不到一个超平面能够将样本无误的分开(不过拟合的前提下),必须得错分一些点。此时图©中虚线与实线之间的间隔就称之为软间隔(soft margin)

2 软间隔最大化

此时我们可以知道,如数据集中出现了异常点,必将导致该异常点的函数间隔小于1。所以,此时引入一个松弛变量( ξ > 0 \xi>0 ξ>0),使得函数加上松弛变量大于等于1.
y ( i ) ( w T x ( i ) + b ) ≥ 1 − ξ i (1) y^{(i)}(w^Tx^{(i)}+b)\geq1-\xi_i\tag {1} y(i)(wTx(i)+b)1ξi(1)

那么此时的目标函数可以重新改写为如下形式:
min ⁡ w , b , ξ 1 2 ∣ ∣ w ∣ ∣ 2 + C ∑ i = 1 m ξ i s . t .      y ( i ) ( w T x ( i ) + b ) ≥ 1 − ξ i , i = 1 , 2 , . . . m ξ i ≥ 0 , i = 1 , 2 , . . . m (2) \begin{aligned} \min_{w,b,\xi} &\frac{1}{2}{||w||^2}+C\sum_{i=1}^m\xi_i\\[1ex] s.t.\;\;&y^{(i)}\large(w^Tx^{(i)}+b)\geq1-\xi_i,i=1,2,...m\\[1ex] &\xi_i\geq0,i=1,2,...m \end{aligned}\tag{2} w,b,ξmins.t.21w2+Ci=1mξiy(i)(wTx(i)+b)1ξi,i=1,2,...mξi0,i=1,2,...m(2)
其中 C > 0 C>0 C>0称为惩罚参数, C C C越大时对误分类的惩罚越大,其作用等同于正则化中的 λ \lambda λ 。最小化目标函数 ( 2 ) (2) (2)包含两层含义:使 1 2 ∣ ∣ w ∣ ∣ 2 \frac{1}{2}||w||^2 21w2尽量小,即分类间隔尽量大,同时使误分类点的个数尽量小, C C C是调和二者的系数。并且只要错分一个样本点,我们都将付出 C ξ i C\xi_i Cξi的代价。

3 示例

跟我一起机器学习系列文章将首发于公众号:月来客栈,欢迎文末扫码关注!

3.1 API介绍

在前面一篇文章中,我们大致列出了sklearn中SVM分类器的几个重要参数,如下所示:

def __init__(self, C=1.0, 
             kernel='rbf', 
             degree=3,):

其中 C C C就表示式子 ( 2 ) (2) (2)中的惩罚系数,它的作用就是用来控制容忍决策面错分的程度,越大则模型越偏向于过拟合。如下图所示为不同取值 C C C下的决策面(红色决策面中 C = 1 C=1 C=1,绿色决策面中 C = 1000 C=1000 C=1000):

SVM之软间隔最大化_第2张图片

参数kernel表示选择哪种核函数,当kernel='poly'时可以用参数degree来选择多项式的次数。但是通常情况下我们都会选择效果更好的高斯核来作为核函数,因此该参数用得比较少。

3.2 分类示例

下面我们采用网格搜索来选择一个最佳SVM分类器对数据集iris进行分类。从上面对API的介绍可知,SVC中我们需要用到的超参数有3个,其取值我们分别设为:'C': np.arange(1, 100, 5)'kernel': ['rbf', 'linear', 'poly']'degree': np.arange(1, 20, 2)由此我们就能得到 20 × 3 × 10 = 600 20\times3\times10=600 20×3×10=600个备选模型。同时,我们以5折交叉验证进行训练。

  • 模型选择

    def model_selection(x_train, y_train):
        model = SVC()
        paras = {'C': np.arange(1, 100, 5), 
                 'kernel': ['rbf', 'linear', 'poly'], 'degree': np.arange(1, 20, 2)}
        gs = GridSearchCV(model, paras, cv=5, verbose=1, n_jobs=-1)
        gs.fit(x_train, y_train)
        print('best score:', gs.best_score_)
        print('best parameters:', gs.best_params_)
    
    #输出结果:
    Fitting 5 folds for each of 600 candidates, totalling 3000 fits
    [Parallel(n_jobs=-1)]: Using backend LokyBackend with 96 concurrent workers.
    [Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:    1.3s
    [Parallel(n_jobs=-1)]: Done 258 tasks      | elapsed:    2.2s
    [Parallel(n_jobs=-1)]: Done 608 tasks      | elapsed:    3.6s
    [Parallel(n_jobs=-1)]: Done 1058 tasks      | elapsed:    5.4s
    [Parallel(n_jobs=-1)]: Done 1608 tasks      | elapsed:    7.7s
    [Parallel(n_jobs=-1)]: Done 2258 tasks      | elapsed:   10.5s
    [Parallel(n_jobs=-1)]: Done 3000 out of 3000 | elapsed:   13.3s finished
    best score: 0.9912413836716626
    best parameters: {'C': 6, 'degree': 1, 'kernel': 'rbf'}
    

    可以看出,当惩罚系数C=6,以及选取高斯核函数时对应的模型效果最好。(由于选取的是高斯核所以此时参数degree无效。

  • 训练与预测

    def train(x_train, x_test, y_train, y_test):
        model = SVC(C=6, kernel='rbf')
        model.fit(x_train, y_train)
        score = model.score(x_test, y_test)
        print("accuracy: ", score)
    
    #输出结果:
    accuracy:  0.9851851851851852
    

    可以看出,此时模型在测试集上的准确率为0.98。

4 总结

在这篇文章中,笔者首先介绍了什么是软间隔及其原理;然后以iris数据集分类为例,介绍了在sklearn中如何用网格搜索来寻找最佳的模型参数。到此,我们就完成了对于SVM算法第一阶段的学习,内容不少也有相应的难度。在接下来的三篇文章中,笔者将首先和大家一起回顾一下好久不见的拉格朗日乘数法;然后再介绍求解模型参数需要用到的对偶问题;最后是一部分的求解过程,也就是得到 w w w的解析式,用以理解核函数。遗憾的是对于完整的求解过程,笔者也没有弄得十分清楚,因此那一部分就没有再写出来。有兴趣的可以自行查找相关资料进行学习。本次内容就到此结束,感谢阅读!

若有任何疑问与见解,请发邮件至[email protected]并附上文章链接,青山不改,绿水长流,月来客栈见!

引用

[1]《统计机器学习(第二版)》李航,公众号回复“统计学习方法”即可获得电子版与讲义

[2] Andrew Ng. CS229. Note3 http://cs229.stanford.edu/notes/cs229-notes3.pdf

[3]示例代码 :示例代码:关注公众号回复“示例代码”即可直接获取!

近期文章

[1]原来这就是支持向量机

[2]从另外一个角度理解支持向量机

[3]SVM之sklearn建模与线性不可分

你可能感兴趣的:(跟我一起机器学习)