SVM鸢尾花分类

最近学习了SVM支持向量机模型,并用uci上的数据集做了简单的鸢尾花分类。

svm数学原理SVM鸢尾花分类_第1张图片
SVM鸢尾花分类_第2张图片
代码

from sklearn import svm
import numpy as np
from sklearn import model_selection
import matplotlib.pyplot as plt

#导入数据集
path = "/Users/lkr/Documents/deep_learning/SVM/iris.data" #文件路径

#因为数据集最后一列是花的种类,要将string类型转换成float类型,定义一个转换函数,每个花的种类是一个数字分别是0,1,2
def iris_type(s):
    it = {b'Iris-setosa': 0, b'Iris-versicolor': 1, b'Iris-virginica': 2}
    return it[s]

#转换数据 (路径, 数据类型float, 以逗号分割, 将最后一列数据用上述函数进行转换)
data = np.loadtxt(path, dtype=float, delimiter=",", converters={4: iris_type})

#老规矩将数据分为训练集和测试集
X, y = np.split(data, (4,), axis=1) #第五列为y,之前是X
y=y.reshape((-1)) #将y变成一维list,否则会报warning
x = X[:, 0:2] #x数据是将X再次划分,0,1 这两列为x特征
# 30%作为测试集, 随机种子数设为1
x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, random_state=1, test_size=0.3)


#上面将数据集处理好后,下面进行模型搭建

#核函数
"""
其中kernel表示用什么核函数,这里用最常用的rbf
gamma是核函数rbf的参数,设置为0.1
decision_function_shape是选择one vs one,属于svm多分类问题中的一对一分类
C是C-SVM惩罚参数C
C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,
这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。
"""
clf = svm.SVC(kernel='rbf', gamma=0.1, decision_function_shape='ovo', C=0.8)

#训练
clf.fit(x_train, y_train)
#打印训练的准确率
print(clf.score(x_train, y_train))

#打印测试的准确率
print(clf.score(x_test, y_test))

#将原始结果和预测结果放在一起对比,更加直观反映
y_train_hat=clf.predict(x_train) #y预测结果
y_train_1d=y_train.reshape((-1)) #数据集中y的原本结果
comp=zip(y_train_1d,y_train_hat) #二者放一起
print(list(comp))

#可视化
#参数c表示要按照y_train中的类别对散点进行上色,edgecolor表示散点的边的颜色,s表示散点的大小
plt.figure(dpi=80, figsize=(15,8))

plt.subplot(121)
plt.title('Training data', fontsize=20)
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train.reshape((-1)), edgecolors='k',s=50)
plt.xlabel('sepal length', fontsize = 14)
plt.ylabel('sepal width', fontsize = 14)

plt.subplot(122)
plt.title('Training data prediction', fontsize=20)
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train_hat.reshape((-1)), edgecolors='k',s=50)
plt.xlabel('Calyx length', fontsize = 14)
plt.ylabel('Calyx width', fontsize = 14)
plt.show()

源码和文件下载地址
https://github.com/kr-li/SVM

你可能感兴趣的:(python,可视化,深度学习)