本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
SVM分类实战1之简单SVM分类
SVM分类实战2线性SVM
SVM分类实战3非线性SVM
支持向量机(Support Vector Machines,SVM),用于分类和回归任务的强大模型,通过找到最佳的超平面分离不同类别
import numpy as np
import os
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')
plt.rcParams['axes.labelsize'] = 14
: 这行代码设置了matplotlib图形中坐标轴标签的字体大小为14。plt.rcParams['xtick.labelsize'] = 12
: 这行代码设置了matplotlib图形中x轴刻度标签的字体大小为12。plt.rcParams['ytick.labelsize'] = 12
: 这行代码设置了matplotlib图形中y轴刻度标签的字体大小为12。warnings.filterwarnings('ignore')
: 这行代码将警告消息的输出设置为忽略(不显示),这意味着在后续的代码执行中,如果有警告消息产生,它们将不会在屏幕上显示。from sklearn.svm import SVC
from sklearn import datasets
iris = datasets.load_iris()
X = iris['data'][:, (2, 3)]
y = iris['target']
setosa_or_versicolor = (y == 0) | (y == 1)
X = X[setosa_or_versicolor]
y = y[setosa_or_versicolor]
svm_clf = SVC(kernel='linear', C=1e10)
svm_clf.fit(X, y)
setosa_or_versicolor = (y == 0) | (y == 1)
: 这行代码创建了一个布尔掩码数组setosa_or_versicolor
,用于选择类别为0(山鸢尾)或1(变色鸢尾)的样本。X = X[setosa_or_versicolor]
: 这行代码根据布尔掩码选择了对应的特征数据,只保留了山鸢尾和变色鸢尾的数据。y = y[setosa_or_versicolor]
: 这行代码根据布尔掩码选择了对应的目标类别,只保留了山鸢尾和变色鸢尾的类别标签。svm_clf = SVC(kernel='linear', C=1e10)
: 这行代码创建了一个SVM分类器的实例svm_clf
,使用线性核函数(kernel=‘linear’)和一个非常大的正则化参数C(C=1e10),这将导致模型尽量不允许分类错误。svm_clf.fit(X, y)
: 这行代码用选定的特征和目标数据训练了SVM分类器,以学习如何分辨山鸢尾和变色鸢尾。在训练完成后,svm_clf
模型将能够对新的鸢尾花样本进行分类,并决定它们属于山鸢尾(类别0)还是变色鸢尾(类别1)。这是一个简单的二元分类示例,演示了如何使用SVM来处理数据集中的两个类别。
先看看一般的模型是怎么进行一个分类任务的,画出回归线隔开散点图对其分类,再用svm画出分割线最对比,看最终的效果怎么样。
x0 = np.linspace(0, 5.5, 200)
pred_1 = 5*x0 - 20
pred_2 = x0 - 1.8
pred_3 = 0.1 * x0 + 0.5
def plot_svc_decision_boundary(svm_clf, xmin, xmax,sv=True):
w = svm_clf.coef_[0]
b = svm_clf.intercept_[0]
print (w)
x0 = np.linspace(xmin, xmax, 200)
decision_boundary = - w[0]/w[1] * x0 - b/w[1]
margin = 1/w[1]
gutter_up = decision_boundary + margin
gutter_down = decision_boundary - margin
if sv:
svs = svm_clf.support_vectors_
plt.scatter(svs[:,0],svs[:,1],s=180,facecolors='#FFAAAA')
plt.plot(x0,decision_boundary,'k-',linewidth=2)
plt.plot(x0,gutter_up,'k--',linewidth=2)
plt.plot(x0,gutter_down,'k--',linewidth=2)
前面我们已经定义了一个SVM分类器的实例svm_clf
,使用线性核函数(kernel='linear')
,表示回归方程就是一个 y = w 0 x 0 + w 1 x 1 + b y = w_0x_0+w_1x_1+b y=w0x0+w1x1+b的形式。
plt.figure(figsize=(14,4))
plt.subplot(121)
plt.plot(X[:,0][y==1],X[:,1][y==1],'bs')
plt.plot(X[:,0][y==0],X[:,1][y==0],'ys')
plt.plot(x0,pred_1,'g--',linewidth=2)
plt.plot(x0,pred_2,'m-',linewidth=2)
plt.plot(x0,pred_3,'r-',linewidth=2)
plt.axis([0,5.5,0,2])
plt.subplot(122)
plot_svc_decision_boundary(svm_clf, 0, 5.5)
plt.plot(X[:,0][y==1],X[:,1][y==1],'bs')
plt.plot(X[:,0][y==0],X[:,1][y==0],'ys')
plt.axis([0,5.5,0,2])
[1.29411744 0.82352928]
(0.0, 5.5, 0.0, 2.0)
SVM分类实战1之简单SVM分类
SVM分类实战2线性SVM
SVM分类实战3非线性SVM