https://blog.csdn.net/v_july_v/article/details/7624837
原理：找到一个超平面，将不同类别的数据分隔开；SVM 在所有可行的分隔超平面中，选择使两类样本间隔（margin）最大的那一个。
超平面
超平面是平面中的直线、空间中的平面在更高维度上的推广：n 维空间中的超平面是一个 n−1 维的（仿射）子空间。通常只有当 n 大于 3 时才称其为“超”平面。
例如，在三维空间中，超平面是二维平面；在二维空间中，超平面是一维直线。
sklearn 提供了三种基于 SVM 的分类器：
sklearn.svm.NuSVC()
sklearn.svm.LinearSVC()
sklearn.svm.SVC()
"""Fit an SVM classifier on a toy 2-D dataset and plot its decision regions.

Label-1 samples lie on the four edges of the square [0, 4] x [0, 4];
label-0 samples form a block in the middle of that square.  Because the
label-1 points surround the label-0 points, no straight line separates the
two classes, so the demo relies on the default (RBF) kernel of
``sklearn.svm.SVC``.
"""
import numpy as np

# Positions sampled along each edge of the square (all labelled 1).
EDGE_OFFSETS = (0.0, 0.2, 0.5, 0.7, 0.9, 1.0, 1.2, 1.5, 1.7, 1.9,
                2.0, 2.3, 2.5, 2.8, 3.0, 3.1, 3.5, 3.6, 3.7, 3.8)

# Interior samples (all labelled 0), as (x, sampled y values) columns.
# Note the x = 3.0 column uses 2.0 / 2.2 where the others use 2.1 / 2.3.
CENTER_COLUMNS = (
    (2.0, (1.0, 1.2, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7, 2.9)),
    (2.5, (1.0, 1.2, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7, 2.9)),
    (3.0, (1.0, 1.2, 1.5, 1.7, 1.9, 2.0, 2.2, 2.5, 2.7, 2.9)),
)


def build_training_data():
    """Return the toy training set as ``(features, labels)``.

    ``features`` is a float array of shape (110, 2) holding (x, y) points;
    ``labels`` is a float array of shape (110,).  The first 80 rows are the
    label-1 edge samples (edges x=0, x=4, y=4, y=0, in that order), followed
    by the 30 label-0 interior samples.
    """
    rows = []
    rows += [[0.0, t, 1] for t in EDGE_OFFSETS]  # left edge   (x = 0)
    rows += [[4.0, t, 1] for t in EDGE_OFFSETS]  # right edge  (x = 4)
    rows += [[t, 4.0, 1] for t in EDGE_OFFSETS]  # top edge    (y = 4)
    rows += [[t, 0.0, 1] for t in EDGE_OFFSETS]  # bottom edge (y = 0)
    for x, ys in CENTER_COLUMNS:
        rows += [[x, y, 0] for y in ys]
    data = np.array(rows, dtype=float)
    return data[:, 0:2], data[:, 2]


def grid_points(low, high, num):
    """Return a (num * num, 2) array of points on a regular square grid.

    Both axes run over ``np.linspace(low, high, num)``; x varies fastest.
    """
    axis = np.linspace(low, high, num)
    grid_x, grid_y = np.meshgrid(axis, axis)
    return np.c_[grid_x.ravel(), grid_y.ravel()]


def select_label(points, labels, label):
    """Return ``(xs, ys)`` for the rows of ``points`` whose label equals ``label``.

    ``points`` is an (N, 2) array and ``labels`` a length-N array.
    """
    mask = np.asarray(labels) == label
    return points[mask, 0], points[mask, 1]


def drawPoint(x, y, color, marker):
    """Scatter-plot the points ``(x, y)`` with the given color and marker."""
    # Imported lazily so the data helpers above work without matplotlib.
    import matplotlib.pyplot as plt

    plt.scatter(x, y, color=color, marker=marker)


def main():
    """Train ``SVC`` on the toy data, print predictions and plot the result."""
    # Imported lazily so the data helpers above work without these packages.
    import matplotlib.pyplot as plt
    from sklearn import svm

    train_x, train_y = build_training_data()

    clf = svm.SVC()
    clf.fit(train_x, train_y)

    predict_y = clf.predict([[2.4, 2.4], [2.3, 1.3], [0.2, 0.5], [3.7, 3.5], [7, 7]])
    print(predict_y)

    # Classify a dense 100 x 100 grid over [0, 4] x [0, 4] to show the regions.
    grid = grid_points(0, 4, 100)
    grid_pred = clf.predict(grid)
    print(grid_pred)

    # Decision regions.
    drawPoint(*select_label(grid, grid_pred, 0), 'r', '+')
    drawPoint(*select_label(grid, grid_pred, 1), 'g', '+')
    # Training samples.
    drawPoint(*select_label(train_x, train_y, 0), 'b', 'o')
    drawPoint(*select_label(train_x, train_y, 1), 'm', 'v')
    plt.show()


if __name__ == "__main__":
    main()