Sklearn闲谈:基于鸢尾花数据集实现KNN算法

目录

  • 写在前面
    • Sklearn介绍
    • KNN介绍
  • 实现流程
    • 导入鸢尾花数据集
    • 交叉验证确定K的值
    • 数据的输入与拟合
    • 数据可视化实现
    • 完整代码

写在前面

总体来说算是第一篇博客吧,博主只是一名工业工程专业的菜鸡,接触python没多久,但由于导师的项目需求迫不得已地开始接触这个门槛极高的机器学习领域,python是一个好语言,在数据处理方面有着其他语言难以企及的优势,何况诸多国内与海外的大佬们开发了许多便利的库,但可惜运算速度一直都是解释性语言的硬伤……扯远了,这篇文章将对python的KNN算法的实现作简要概述,希望可以帮助到完全萌新的小白,不过由于博主知识有限,这里只对具体用法作简要解释说明,不会对算法本身实现进一步探讨。

Sklearn介绍

Sklearn是用于预测数据分析的简单而有效的工具,这些工具基于Numpy、SciPy和matplotlib开放源代码,是一种功能较为完善,使用较为简单的机器学习库,安装方式与其他库的安装方式相同,这里不作过多赘述。

KNN介绍

KNN算法又叫最近邻算法,顾名思义,就是把邻近的数据判断为一种类型,在sklearn中,knn参数如下

KNeighborsClassifier(n_neighbors = 5,weights = 'uniform',algorithm = '',leaf_size = '30',p = 2,metric = 'minkowski',metric_params = None,n_jobs = None)

在不输入任何参数的情况下,该函数会根据默认值执行,当然一般情况下也不需要人为干预太多,往往地为了提高模型识别准确性,我们只需要确定里面的’n_neighbors’值(也就是KNN中所谓的K)就可以了,想要了解各个参数的具体含义,可以参考这篇博客(当然不是我的)

实现流程

导入鸢尾花数据集

Sklearn中已经自带了iris数据集,所以并不需要在安装完之后特意去寻找.csv文件,导入数据集的代码是

from sklearn.datasets import load_iris
iris = load_iris()#读取鸢尾花数据集
#print(iris)

只需要两行代码就可以完成数据集的导入,如果需要查看数据集中的内容则可以通过"print"打印出来。
iris中主要有三种数据:‘data’,‘target’,‘target_names’,分别是数据,标签以及标签名称,data和target对应,也就是说在下标相同的情况下,比如data中的第一个数据和target中的第一个数据所描述的都是一朵鸢尾花。
接下来将data和target分别提取出来

x = iris.data
y = iris.target

这样数据就提取出来了,不严谨地来说,单单靠着这两个数据就能直接训练KNN模型。但是,前文中提到了K的问题,如果选取K过小,那么会发生过拟合,如果选取K过大则又会欠拟合,下面介绍一下K怎么确定。

交叉验证确定K的值

由于样本数据较小,所以我们可以采用交叉验证的方法

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score 
k_error = []
for k in range(1,31):
	knn = KNeighborsClassifier(n_neighbors=k)
	scores = cross_val_score(knn,x,y,cv=5,scoring='accuracy')
	k_error.append(1-scores.mean())

cross_val_score()就是交叉验证了,将data和target输入,通过设置cv参数来确定训练集和测试集(划分比例是cv:1),然后就能得到一个正确率,关于scoring的选择则参考官方文档这里不作赘述(其实也不太晓得)
好了,测试的结果出来了,错误率已经被append到了k_error数组里面,接下来只需要找到里面的最小值,然后导出下标,那个下标就是我们要用到的“K”
那,怎么导出下标勒?这里用的是一种比较笨的方法

K = {num:i for i,num in enumerate(k_error)}[min(k_error)]

没错,就是把它存储到字典里面然后导出……当然有更好的方法还是比较建议用其他方法的,如果数据很多的话这种方法会挺浪费运算速度的。

数据的输入与拟合

有了K,其实问题已经解决一大半了,剩下的就是用sklearn自带的方法来完成数据的拟合与新数据的判断,KNN模型的建立已经在前文中提到过了

clf = KNeighborsClassifier(n_neighbors = K)
X = x[:,:3]
clf.fit(X,y)

好了,三行代码数据拟合完了,不过这里的数据只选取了前三列(鸢尾花数据集有四列数据),主要是为了方便后面的可视化实现,毕竟暂时没有办法弄出来四维坐标轴嘛……
然后输入一个数据,让模型判断下这是个啥

point = [[7.1,3,4]]
answer = clf.predict(point)
print("结果是%d" % answer)

输出结果
emmm,结果是1,也就是说它的名字是’target_names[1]’,但是图都没有不好判断欸。

数据可视化实现

代码和数据都是抽象的,为了更为直观地‘观测’到,必须将它用图表的形式呈现出来,为了提高美观性(逼格),这里用三维图呈现出来。

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

figure = plt.figure().add_subplot(projection='3d')
figure.scatter(x[:,0],x[:,1],x[:,2],c=y)#绘制散点图,颜色直接用y里面的数据
figure.set_xlabel("1st eigenvector")
figure.set_ylabel("2nd eigenvector")
figure.set_zlabel("3rd eigenvector")
plt.show()

Sklearn闲谈:基于鸢尾花数据集实现KNN算法_第1张图片
然后再把刚才的点也给捣鼓出来

figure.scatter(point[0][0],point[0][1],point[0][2],c = 'r')
plt.show()

嗯,为了显眼点这里就用红色了
Sklearn闲谈:基于鸢尾花数据集实现KNN算法_第2张图片

完整代码

完整代码如下

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from mpl_toolkits.mplot3d import Axes3D

iris = load_iris()
x = iris.data
y = iris.target
k_range = range(1,31)
k_error = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn,x,y,cv=5,scoring='accuracy')
    #5:1划分训练集和验证集,进行交叉验证得出评估分数
    #print("%d:%s" % (k,scores))
    k_error.append(1-scores.mean())

#plt.plot(k_range,k_error)
#plt.xlabel('Value of K for KNN')
#plt.ylabel('Error')
#plt.show()
'''
上述代码求出不同k取值对误差的影响
'''
K = {num:i for i,num in enumerate(k_error)}[min(k_error)]
#print(n_neighbors)
#将k_error中最小值的下标取出,将它作为K(代码参考Leecode:Two Sum)
figure = plt.figure().add_subplot(projection='3d')
figure.scatter(x[:,0],x[:,1],x[:,2],c=y)
figure.set_xlabel("1st eigenvector")
figure.set_ylabel("2nd eigenvector")
figure.set_zlabel("3rd eigenvector")
#三维图形绘制
clf = KNeighborsClassifier(n_neighbors = K)

X = x[:,:3]
clf.fit(X,y)

point = [[7.1,3,4]]

answer = clf.predict(point)
figure.scatter(point[0][0],point[0][1],point[0][2],c = 'r')
print("结果是:%d" % answer)
plt.show()
#预测输入点,并且输出结果

参考
·https://scikit-learn.org/stable/index.html

你可能感兴趣的:(Sklearn闲谈:基于鸢尾花数据集实现KNN算法)