在西瓜数据集 3.0α 上分别用线性核和高斯核训练一个 SVM,并比较其支持向量的差别。
数据集下载地址:
https://amazecourses.obs.cn-north-4.myhuaweicloud.com/datasets/watermelon_3a.csv
任选数据集中的一种分布类型的数据,分别用软、硬间隔SVM和各类核函数训练,并分析他们分类的效果。
数据集下载地址:https://amazecourses.obs.cn-north-4.myhuaweicloud.com/datasets/SVM.zip
此博客为第二问,各类SVM的实现。上一篇博客为https://blog.csdn.net/qq_44459787/article/details/111409314
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/12/19 20:48
# @Author : Ryu
# @Site :
# @File : processing_data.py
# @Software: PyCharm
import numpy as np
from sklearn.model_selection import train_test_split
def npz_read(file_dir):
npz = np.load(file_dir)
data = npz['data']
label_list = npz['label']
npz.close()
return data, label_list
def split_train_test(data, label_list):
xtrain, xtest, ytrain, ytest = train_test_split(data, label_list, test_size=0.3)
return xtrain, xtest, ytrain, ytest
if __name__ == '__main__':
file_name = r'D:\Pythonwork\FisherLDA\SVM\2\分布1.npz'
data, label_list = npz_read(file_name)
xtrain, xtest, ytrain, ytest = split_train_test(data, label_list)
npz为数据集自带的解析算法。分类数据集依然使用的是sklearn的split方法。
2. 核函数实现
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/12/19 21:07
# @Author : Ryu
# @Site :
# @File : SVM.py
# @Software: PyCharm
from sklearn import svm
from sklearn.metrics import accuracy_score
from SVM2.processing_data import *
if __name__ == '__main__':
file_name = 'D:\Pythonwork\FisherLDA\SVM\SVM2\分布1.npz'
data, label_list = npz_read(file_name)
xtrain, xtest, ytrain, ytest = split_train_test(data, label_list)
# 线性核处理
linear_svm = svm.LinearSVC(C=0.5, class_weight='balanced')
linear_svm.fit(xtrain, ytrain)
y_pred = linear_svm.predict(xtest)
print('线性核的准确率为:{}'.format(accuracy_score(y_pred=y_pred, y_true=ytest)))
# 高斯核处理
gauss_svm = svm.SVC(C=0.5, kernel='rbf', class_weight='balanced')
gauss_svm.fit(xtrain, ytrain)
y_pred2 = gauss_svm.predict(xtest)
print('高斯核的准确率: %s' % (accuracy_score(y_pred=y_pred2, y_true=ytest)))
# 多项式核
poly_svm = svm.SVC(C=0.5, kernel='poly', degree=3, gamma='auto', coef0=0, class_weight='balanced')
poly_svm.fit(xtrain, ytrain)
y_pred3 = poly_svm.predict(xtest)
print('多项式核的准确率: %s' % (accuracy_score(y_pred=y_pred3, y_true=ytest)))
# sigmoid核
sigmoid_svm = svm.SVC(C=0.5, kernel='sigmoid', degree=3, gamma='auto', coef0=0, class_weight='balanced')
sigmoid_svm.fit(xtrain, ytrain)
y_pred4 = sigmoid_svm.predict(xtest)
print('sigmoid核的准确率: %s' % (accuracy_score(y_pred=y_pred4, y_true=ytest)))
#sigmoid核硬间隔
sigmoid_hard_svm = svm.SVC(C=1000000, kernel='sigmoid', degree=3, gamma='auto', coef0=0, class_weight='balanced')
sigmoid_hard_svm.fit(xtrain, ytrain)
y_pred5 = sigmoid_hard_svm.predict(xtest)
print('sigmoid核硬间隔的准确率: %s' % (accuracy_score(y_pred=y_pred5, y_true=ytest)))
实现部分主要是实现四类核函数和硬间隔支持向量机。
3. 结果分析
经过多次试验之后,这张结果记录是出现频率最多的。线性核的分类效果很不错,但是并没有高斯核和多项式核的稳定,sigmoid核的分类效果相当的差,但是软间隔的分类效果还是相较于硬间隔更好。
此图的样本数据分类效果本身就很明显,核函数的效果好也是在期望之内的。