KNN实现鸢尾花分类

因为我们有已知品种的鸢尾花的测量数据,所以这是一个监督学习问题。在这个问题中,我们要在多个选项中预测其中一个(鸢尾花的品种)。这是一个分类问题,可能的输出(鸢尾花的不同品种)叫做类别(class)。数据集中的每朵鸢尾花都属于三个类别之一,所以这是一个三分类问题。单个数据点(一朵鸢尾花)的预期输出是这朵花的品种。对于一个数据点来说,它的品种叫做标签(label)。

一、加载数据集

鸢尾花(Iris)数据集包含在scikit-learn的datasets模块中,我们可以调用load_iris函数来加载数据集:

from sklearn.datasets import load_iris
iris_dataset = load_iris()

load_iris返回的iris对象是一个Bunch对象,与字典非常相似,里面包含键和值:

print('Keys of iris_dataset:\n{}'.format(iris_dataset.keys()))

 target_names键对应的是一个字符串数组,里面包含我们要预测的花的品种:

iris_dataset['target_names']

feature_names键对应的值是一个字符串列表,对每一个特征进行了说明:

KNN实现鸢尾花分类_第1张图片

iris_dataset['feature_names']

KNN实现鸢尾花分类_第2张图片 data里面是花萼长度、花萼宽度、花瓣长度、花瓣宽度的测量数据,格式为Numpy数组,data数组的每一行对应一朵花,列代表每朵花的4个测量数据:

type(iris_dataset['data'])
iris_dataset['data'].shape

KNN实现鸢尾花分类_第3张图片

 target数组包含的是测量过的每朵花的品种,也是一个Numpy数组,它是一维数组,每朵花对应其中一个数据:

type(iris_dataset['target'])
iris_dataset['target'].shape

KNN实现鸢尾花分类_第4张图片

品种被转换成了0-2的整数,0代表setosa,1代表versicolor,2代表virginica。

iris_dataset['target']

KNN实现鸢尾花分类_第5张图片

 DESCR键对应的值是数据集的简要说明,这里给出开头的部分:

print(iris_dataset['DESCR'][:193]+'\n...')

KNN实现鸢尾花分类_第6张图片

 filename可以看到下载的iris数据集的文件的地址:

iris_dataset['filename']

 二、训练数据与测试数据

train_test_split函数可以打乱数据集并进行拆分,这个函数将75%的行数据及对应标签作为训练集,剩下25%的数据及其标签作为测试集。scikit-learn中的数据通常用大写的X表示,而标签用小写的y表示。大写的X是一个二维数组(矩阵),小写的y是因为目标是一个一维数组(向量)。在对数据进行拆分之前,train_test_split函数利用伪随机数生成器将数据集打乱。我们利用random_state参数指定了随机数的种子。

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(iris_dataset['data'],iris_dataset['target'],random_state=0)

KNN实现鸢尾花分类_第7张图片

 三、数据可视化

绘制散点图矩阵,可以两两查看所有的特征,散点图矩阵无法同时显示所有特征之间的关系。

我们首先将Numpy数组转换成pandas Dataframe,pandas有一个绘制散点图矩阵的函数,叫做scatter_matrix。矩阵的对角线是每个特征的直方图:

import pandas as pd
import mglearn
# 利用X_train中的数据创建dataframe
# 利用iris_dataset.feature_names中的字符串对数据列进行标记
iris_dataframe = pd.DataFrame(X_train,columns = iris_dataset.feature_names)
# 利用dataframe创建散点图矩阵,按y_train着色
grr = pd.plotting.scatter_matrix(iris_dataframe,c=y_train,figsize(15,15),marker='o',
                              hist_kwds={'bins':20},s=60,alpha=.8,cmap=mglearn.cm3)

KNN实现鸢尾花分类_第8张图片

KNN实现鸢尾花分类_第9张图片

 四、构建模型:K近邻算法

k近邻算法的含义是,我们可以考虑训练集中与新数据点最近的任意k个邻居,然后我们利用这些邻居中数量最多的类别做出预测。

k近邻算法是在neighbors模块中的KNeighborsClassifier类中实现的,我们需要将这个类实例化为一个对象,才能使用这个模型。

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)

knn对象对算法进行了封装,既包括用训练数据构建模型的算法,也包括对新数据点进行预测的算法。还包括算法从训练数据中提取的信息。对于KNeighborsClassifier来说,里面只保存了训练集。

想要基于训练集来构建模型,需要调用knn对象的fit方法,输入参数为X_train和y_train,二者都是numpy数组,前者包含训练数据,后者包括训练标签:

knn.fit(X_train,y_train)

fit方法返回的是knn对象本身并做原处修改,因此我们得到了分类器的字符串表示。

五、做出预测

import numpy as np
X_new = np.array([[5,2.9,1,0.2]])
prediction = knn.predict(X_new)
prediction
iris_dataset['target_names'][prediction]

KNN实现鸢尾花分类_第10张图片

根据我们的预测,这朵新的鸢尾花属于类别0,也就说它属于setosa品种。

 六、评估模型

y_pred = knn.predict(X_test)
y_pred
np.mean(y_pred == y_test)
knn.score(X_test,y_test)

KNN实现鸢尾花分类_第11张图片

 对于这个模型来说,测试集的精度约为0.97.

采用决策树分类鸢尾花数据集,可以参考Iris数据集实战 - 徐-清风 - 博客园

你可能感兴趣的:(机器学习,分类,机器学习,人工智能)