目录
一、什么是SVM
二、最大间隔与分类
三、对偶问题
一、等式约束
二、不等式约束的KKT条件
三、KKT
四、SMO高效优化算法
五、通过SMO-SVM实现对莺尾花数据集的二分类
六、总结
SVM是一种监督机器学习算法,是一种二分类模型,它的目的是寻找一个超平面来对样本进行分割,分割的原则是间隔最大化,最终转化为一个凸二次规划问题来求解。可用于分类或回归挑战。然而,它主要用于分类问题。
如果一个线性函数能够将样本分开,称这些数据样本是线性可分的。
那么什么是线性函数呢?其实很简单,在二维空间中就是一条直线,在三维空间中就是一个平面,以此类推,如果不考虑空间维数,这样的线性函数统称为超平面。
我们看一个简单的二维空间的例子,+代表正类,-代表负类,样本是线性可分的,但是很显然不只有这一条直线可以将样本分开,而是有无数条,我们所说的线性可分支持向量机就对应着能将数据正确划分并且间隔最大的直线。
在样本空间中寻找一个超平面, 将不同类别的样本分开。
最大化间隔: 寻找参数w和b , 使得下述公式最大:
给定一个目标函数 f : Rn→R,希望找到xRn,在满足约束条件g(x)=0的前提下,使得f(x)有最小值。该约束优化问题记为:
分别对待求解参数求偏导,可得:
一般联立方程组可以得到相应的解。
SMO算法的目标是求出一系列alpha和b,一旦求出这些alpha,就很容易算出权重向量w并得到分割超平面。
SMO算法工作原理:每次循环中选择两个alpha进行优化处理。一旦找到一对合适的alpha,那么就增大其中一个,减小另外一个。
算法流程:每次选取两个a进行更新
svm.py
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 鸢尾花(iris)数据集
# 数据集内包含 3 类共 150 条记录,每类各 50 个数据,
# 每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,
# 可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。
# 这里只取前100条记录,两项特征,两个类别。
def create_data():
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
data = np.array(df.iloc[:100, [0, 1, -1]])
for i in range(len(data)):
if data[i,-1] == 0:
data[i,-1] = -1
return data[:,:2], data[:,-1]
#使用RBF(Radial basis function)核函数处理
def K(x,z,sigma=1.5):
return np.exp(np.dot((x-z),(x-z).T)/-(2*sigma**2))
#对应课本147页的g(x_i),该函数助于验证KKT条件
def g(i,x,y,alpha,b):
sum=b
for j in range(len(y)):
sum+=alpha[j]*y[j]*K(x[i],x[j])
return sum
#验证第i个样本点是否满足KKT条件
def isKKT(alpha,i,x,y,b,C):
if alpha[i]==0 and y[i]*g(i,x,y,alpha,b)>=1:
return True
elif alpha[i]==C and y[i]*g(i,x,y,alpha,b)<=1:
return True
elif alpha[i]>0 and alpha[i]
main.py
from svm import *
from numpy import *
from random import *
from matplotlib import pyplot as plt
def train(C=1.0):
#获得数据集
x,y=create_data()
#设定迭代次数为100次
iter=100
#样本容量也就是标签的个数
N=len(y)
#alpha的初始值取全0
alpha=zeros(len(y))
#设置i,j的初始值(对应alpha1和alpha2)
i,j=randint(0,N-1),randint(0,N-1)
#保证i≠j
while i==j:
i=randint(0,N-1)
for k in range(iter):
#x的尺寸为一个1×2行向量
x_i,x_j=x[i],x[j]
#y的取值为+1或-1
y_i,y_j=y[i],y[j]
#计算ita,为计算a2_newunc做准备
ita=K(x_i,x_i)+K(x_j,x_j)-2*K(x_i,x_j)
if ita==0:
continue
#计算分割平面参数w与b
#x:100×2矩阵,w:1×2矩阵
#由于y-dot(w,x.T)是个与y等长的行向量,取其各元素平均值
w=dot(alpha*y,x)
b=mean(y-dot(w,x.T))
#计算误差E1和E2
E_i=E(w,b,x_i,y_i)
E_j=E(w,b,x_j,y_j)
#计算a2_ewunc
a1_old=alpha[i]
a2_old=alpha[j]
a2_newunc=a2_old+y_j*(E_i-E_j)/ita
#计算L与H
L,H=0.0,0.0
if y_i!=y_j:
L=max(0,a2_old-a1_old)
H=min(C,C+a2_old-a1_old)
elif y_i==y_j:
L=max(0,a2_old+a1_old-C)
H=min(C,a2_old+a1_old)
#计算剪辑后a2_new与a1_new的值
a2_new=max(L,min(H,a2_newunc))
a1_new=a1_old+y_i*y_j*(a2_old-a2_new)
#更新alpha
alpha[i],alpha[j]=a1_new,a2_new
#violation表示每个元素违反KKT条件的程度
violation=zeros(N)
#对每一个样本点检验KKT条件,在violation内记录每个样本点违反KKT的程度
for k in range(N):
if isKKT(alpha,k,x,y,b,C)==False:
violation[k]=float(vioKKT(alpha,k,x,y,b))
#如果没有违反KKT条件,则违反程度是0
else:
violation[k]=0.0
#找到violation中违反程度最大的点,设定为i,对应alpha_1
i=findindex(violation,max(violation))
#这里设置j(对应alpha_2)为不等于i的随机数。
#原本alpha_2的选取应该是令abs(E_i-E_k)最大的k值对应的alpha点
#经过测试,在大多数情况下,abs(E_i-E_k)(1×100向量)的所有元素都是0
#即预测每个元素都准确,每个元素的分类误差都是0,误差的差值也是0
#只有少数情况下,会有一个误差差值不等于0
#对于前一种情况,无所谓“最大的误差差值”(因为都是0),因此只能设置j为随机数
#对于后一种情况,由于出现的次数少,并且那一个不为0的差值的元素出现的位置具有随机性
#因此总是将j设定为随机数
j=randint(0,N-1)
while j==i:
j = randint(0, N - 1)
#计算最终(迭代100次)分割平面参数
w = dot(alpha * y, x)
b = mean(y - dot(w, x.T))
draw_x, draw_y, draw_label = [], [], []
#在散点图上标记样本点的位置,样本点第一个元素作为x坐标,第二个元素作为y坐标
for p in x:
draw_x.append(p[0])
draw_y.append(p[1])
#画散点图,其中支持向量呈现绿色,正类呈现红色,负类呈现蓝色
#样本点离分割直线最近的为支持向量
distance=zeros(len(y))
for i in range(len(y)):
distance[i]=distance_count(x[i],w,b)
vector=findindex(distance,min(distance))
for i in range(len(y)):
if i==vector:
draw_label.append('g')
else:
if y[i] > 0:
draw_label.append('r')
else:
draw_label.append('b')
plt.scatter(draw_x, draw_y, color=draw_label)
plain_x = range(4, 8, 1)
plain_y = []
for i in plain_x:
temp = double(-(w[0] * i + b) / w[1])
plain_y.append(temp)
plt.plot(plain_x, plain_y)
#最终绘图
plt.savefig('SMO.jpg')
plt.show()
if __name__ == '__main__':
train()
运行结果截图:
SVM的优点:
1、解决小样本下机器学习问题。
2、解决非线性问题。
3、无局部极小值问题。(相对于神经网络等算法)
4、可以很好的处理高维数据集。
5、泛化能力比较强。
SVM的缺点:
1、对于核函数的高维映射解释力不强,尤其是径向基函数。
2、对缺失数据敏感。
3、调参麻烦,核函数难选难调