用svm对数据进行二分类(完整代码)

一、前言

这是笔者学习opencv中svm的一个小例子,数据集是用sk-learn库中的函数生成的。功能就是对该数据进行二分类。
为了学习的更深入,笔者将svm常用的四种核,linear,inter,sigmoid,rbf作了对比。
详细步骤见代码

二、代码

import numpy as np
from cv2 import cv2
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn import model_selection as ms
from sklearn import metrics


def plot_decision_boundary(svm,x_text,y_text):
    '''
    函数功能,将svm训练出的结果可视化函数,参数说明,
    svm:训练好的svm模型
    x_text:测试数据集
    y_text:测试数据集标签
    '''
    #  确定x轴上的左右边界
    x_min,x_max = x_text[:,0].min()-1,x_text[:,0].max()+1
    #  确定y轴上的左右边界
    y_min,y_max = x_text[:,1].min()-1,x_text[:,1].max()+1
    #  创建合适的网格,h是步长
    h =0.02  
    xx,yy = np.meshgrid(np.arange(x_min,x_max,h),
                        np.arange(y_min,y_max,h))
    #  np.c_功能,按行连接两个矩阵,要求行数一样。
    x_hypo = np.c_[xx.ravel().astype(np.float32),
                   yy.ravel().astype(np.float32)]
    #  开始预测
    _,zz = svm.predict(x_hypo)
    #  大小要一样
    zz = zz.reshape(xx.shape)
    #  绘制三维等高线图,必须在网格结构中才可以
    plt.contourf(xx,yy,zz,cmap=plt.cm.coolwarm,alpha=0.8)
    #  按标签绘制散点图
    plt.scatter(x_text[:,0],x_text[:,1],c=y_text,s=100)




if __name__ == '__main__':
    #  使用自带函数创建一个含有100个样本,特征两个,两个标签的数据样本  
    x,y = datasets.make_classification(n_samples=100,n_features=2,n_classes=2,n_redundant=0,random_state=7816)
    x = x.astype(np.float32)
    y = y*2-1
#  定义一个含各种svm核的列表
    kernels = [cv2.ml.SVM_LINEAR,cv2.ml.SVM_INTER,
           cv2.ml.SVM_SIGMOID,cv2.ml.SVM_RBF]
    #  分割数据集,20%为测试集
    x_train,x_text,y_train,y_text = ms.train_test_split(x,y,test_size=0.2,random_state=42)
    #  训练含有不同核的svm分类器
    for i,kernel in enumerate(kernels):
        svm = cv2.ml.SVM_create()
        #  设置svm核
        svm.setKernel(kernel)
        svm.train(x_train,cv2.ml.ROW_SAMPLE,y_train)
        b,y_pred = svm.predict(x_text)
        #  用测试集计算准确率
        a=metrics.accuracy_score(y_text,y_pred)
        print(a)
        #  创建4个子图
        plt.subplot(2,2,i+1)
        #  可视化结果
        plot_decision_boundary(svm,x_text,y_text)
        plt.title('accuracy: %.2f' %a)
    plt.show()

三、结果

用svm对数据进行二分类(完整代码)_第1张图片
从图中可以看到rbf核准确率最高,rbf核更适合非线性分类

你可能感兴趣的:(opencv中的机器学习,svm,python)