机器学习实战-系列教程6:SVM分类实战1之简单SVM分类(鸢尾花数据集/软间隔/线性SVM/非线性SVM/scikit-learn框架)项目实战、原理解读、代码解读

机器学习 实战系列 总目录

本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

SVM分类实战1之简单SVM分类
SVM分类实战2线性SVM
SVM分类实战3非线性SVM

支持向量机(Support Vector Machines,SVM),用于分类和回归任务的强大模型,通过找到最佳的超平面分离不同类别

  • 与传统算法进行对比,看看SVM究竟能带来什么样的效果
  • 软间隔的作用,这么复杂的算法肯定会导致过拟合现象,如何来进行解决呢?
  • 核函数的作用,如果只是做线性分类,好像轮不到SVM登场了,核函数才是它的强大之处!

1、初始操作

1.1 导包与设置

import numpy as np
import os
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')
  1. plt.rcParams['axes.labelsize'] = 14: 这行代码设置了matplotlib图形中坐标轴标签的字体大小为14。
  2. plt.rcParams['xtick.labelsize'] = 12: 这行代码设置了matplotlib图形中x轴刻度标签的字体大小为12。
  3. plt.rcParams['ytick.labelsize'] = 12: 这行代码设置了matplotlib图形中y轴刻度标签的字体大小为12。
  4. warnings.filterwarnings('ignore'): 这行代码将警告消息的输出设置为忽略(不显示),这意味着在后续的代码执行中,如果有警告消息产生,它们将不会在屏幕上显示。

1.2 读入数据

from sklearn.svm import SVC
from sklearn import datasets
iris = datasets.load_iris()
X = iris['data'][:, (2, 3)]
y = iris['target']
setosa_or_versicolor = (y == 0) | (y == 1)
X = X[setosa_or_versicolor]
y = y[setosa_or_versicolor]
svm_clf = SVC(kernel='linear', C=1e10)
svm_clf.fit(X, y)
  1. 读入鸢尾花分类数据集
  2. 指定数据
  3. 指定标签
  4. setosa_or_versicolor = (y == 0) | (y == 1): 这行代码创建了一个布尔掩码数组setosa_or_versicolor,用于选择类别为0(山鸢尾)或1(变色鸢尾)的样本。
  5. X = X[setosa_or_versicolor]: 这行代码根据布尔掩码选择了对应的特征数据,只保留了山鸢尾和变色鸢尾的数据。
  6. y = y[setosa_or_versicolor]: 这行代码根据布尔掩码选择了对应的目标类别,只保留了山鸢尾和变色鸢尾的类别标签。
  7. svm_clf = SVC(kernel='linear', C=1e10): 这行代码创建了一个SVM分类器的实例svm_clf,使用线性核函数(kernel=‘linear’)和一个非常大的正则化参数C(C=1e10),这将导致模型尽量不允许分类错误。
  8. svm_clf.fit(X, y): 这行代码用选定的特征和目标数据训练了SVM分类器,以学习如何分辨山鸢尾和变色鸢尾。

在训练完成后,svm_clf模型将能够对新的鸢尾花样本进行分类,并决定它们属于山鸢尾(类别0)还是变色鸢尾(类别1)。这是一个简单的二元分类示例,演示了如何使用SVM来处理数据集中的两个类别。

2、简单SVM分类实例

2.1 线性回归模型

先看看一般的模型是怎么进行一个分类任务的,画出回归线隔开散点图对其分类,再用svm画出分割线最对比,看最终的效果怎么样。

x0 = np.linspace(0, 5.5, 200)
pred_1 = 5*x0 - 20
pred_2 = x0 - 1.8
pred_3 = 0.1 * x0 + 0.5
  1. 定义输入数据,200个0到5.5之间均匀分布作为输入数据
  2. 创建一般的模型的预测线,预测线1,w = 5,b=-20
  3. 预测线2
  4. 预测线3

2.2 简单SVM分类模型

def plot_svc_decision_boundary(svm_clf, xmin, xmax,sv=True):
    w = svm_clf.coef_[0]
    b = svm_clf.intercept_[0]
    print (w)
    x0 = np.linspace(xmin, xmax, 200)
    decision_boundary = - w[0]/w[1] * x0 - b/w[1]
    margin = 1/w[1]
    gutter_up = decision_boundary + margin
    gutter_down = decision_boundary - margin
    if sv:
        svs = svm_clf.support_vectors_
        plt.scatter(svs[:,0],svs[:,1],s=180,facecolors='#FFAAAA')
    plt.plot(x0,decision_boundary,'k-',linewidth=2)
    plt.plot(x0,gutter_up,'k--',linewidth=2)
    plt.plot(x0,gutter_down,'k--',linewidth=2)

前面我们已经定义了一个SVM分类器的实例svm_clf,使用线性核函数(kernel='linear'),表示回归方程就是一个 y = w 0 x 0 + w 1 x 1 + b y = w_0x_0+w_1x_1+b y=w0x0+w1x1+b的形式。

  1. 定义一个函数来绘制SVM分类器的决策边界和间隔,传入分类器、输入最大值、输入最小值、函数是否绘制支持向量
  2. 获取SVM分类器的权重
  3. 获取SVM分类器的截距
  4. 打印权重
  5. 一样的 x 0 x_0 x0
  6. 计算决策边界
  7. 计算间隔
  8. 计算上边界
  9. 计算下边界
  10. 当 sv 被设置为 True 时,函数会绘制支持向量(Support Vectors)
  11. 获取支持向量
  12. 绘制支持向量的散点图
  13. 绘制决策边界
  14. 绘制上边界
  15. 绘制下边界

2.3 画图对比

plt.figure(figsize=(14,4))
plt.subplot(121)
plt.plot(X[:,0][y==1],X[:,1][y==1],'bs')
plt.plot(X[:,0][y==0],X[:,1][y==0],'ys')
plt.plot(x0,pred_1,'g--',linewidth=2)
plt.plot(x0,pred_2,'m-',linewidth=2)
plt.plot(x0,pred_3,'r-',linewidth=2)
plt.axis([0,5.5,0,2])

plt.subplot(122)
plot_svc_decision_boundary(svm_clf, 0, 5.5)
plt.plot(X[:,0][y==1],X[:,1][y==1],'bs')
plt.plot(X[:,0][y==0],X[:,1][y==0],'ys')
plt.axis([0,5.5,0,2])
  1. 创建一个图形窗口
  2. 在图形窗口中创建两个子图
  3. 绘制数据点,其中类别为1的用蓝色方块表示,类别为0的用黄色方块表示
  4. 绘制一般的模型预测线
  5. 设置坐标轴范围
  6. 在图形窗口中创建第二个子图
  7. 绘制SVM分类器的决策边界和间隔
  8. 绘制数据点,其中类别为1的用蓝色方块表示,类别为0的用黄色方块表示
  9. 设置坐标轴范围
  10. 显示图形
    打印结果:

[1.29411744 0.82352928]
(0.0, 5.5, 0.0, 2.0)

机器学习实战-系列教程6:SVM分类实战1之简单SVM分类(鸢尾花数据集/软间隔/线性SVM/非线性SVM/scikit-learn框架)项目实战、原理解读、代码解读_第1张图片

SVM分类实战1之简单SVM分类
SVM分类实战2线性SVM
SVM分类实战3非线性SVM

你可能感兴趣的:(机器学习实战,机器学习,sklearn,人工智能,支持向量机,分类,回归)