Python机器学习基础教程——1.7第一个应用:鸢尾花分类——学习笔记

1.7 第一个应用:鸢尾花分类

假设有一名植物学爱好者对她发现的鸢尾花的品种很感兴趣。她收集了每朵鸢尾花的一些测量数据:花瓣的长度和宽度以及花萼的长度和宽度,所有测量结果的单位都是厘米。
       她还有一些鸢尾花分类的测量数据,这些花之前已经被植物学专家鉴定为属于setosa(山鸢尾)、versicolor(杂色)或virginica(维尔吉妮卡)三个品种之一。对于这些测量数据,她可以确定每朵鸢尾花所属的品种。

我们的目标是构建一个机器学习模型,可以从这些已知品种的鸢尾花测量数据中进行学习,从而能够预测新鸢尾花的品种。
因为我们有已知的鸢尾花的测量数据,所以这是一个监督学习问题。在这个问题中,我们要在多个选项中预测其中一个(鸢尾花的品种)。这是一个分类(classification)问题的示例。可能的输出(鸢尾花的品种)叫做类别(class)。数据集中的每朵鸢尾花都属于三个类别之一,所以这是一个三分类问题。
       单个数据点(一朵鸢尾花)的预期输出是这朵花的品种。对于一个数据点来说,它的品种叫做标签(label)

1.7.1 初识数据

本例中我们用到了鸢尾花(Iris)数据集,这是机器学习和统计学中一个经典的数据集。它包含在scikit-learn的datasets模型中。我们可以调用load_iris函数来加载数据:

from sklearn.datasets import load_iris
iris_dataset=load_iris()
print("输出iris_dataset:\n{}".format(iris_dataset))



输出iris_dataset:
{'data': 
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]]), 
'target':
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'target_names': 
array(['setosa', 'versicolor', 'virginica'], dtype='

load_iris返回的iris对象是一个Bunch对象,与字典非常相似,里面包含键和值:

load_iris()返回的是一个Bunch对象,有五个键:

①target_names: 鸢尾花的三个品种

②feature_names: 鸢尾花的四个特征

③DESCR: 对数据集的简要说明

④data: 鸢尾花四个特征的具体数据

⑤target: 鸢尾花的品种,由0,1,2来表示

print("输出Keys of iris_dataset:\n{}".format(iris_dataset.keys()))


输出Keys of iris_dataset:
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
#dict_keys(['数据', '目标', '目标名称', '备注', '特征名称', '文件名'])
#数据:array([[5.1, 3.5, 1.4, 0.2],[4.9, 3. , 1.4, 0.2],[4.7, 3.2, 1.3, 0.2],......])
#目标:array([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2])
#目标名称: array(['setosa', 'versicolor', 'virginica'], dtype='

DESCR键对应的是数据集的简要说明,可以查看一些数据(这不是很重要,不要在意这些细节):

targte_names键对应的值时一个字符串数组,里面包含我们要预测的花的品种:

print("输出Target names:{}".format(iris_dataset['target_names']))

输出Target names:['setosa' 'versicolor' 'virginica']
#目标名称: array(['山鸢尾', '杂色鸢尾', '维尔吉妮卡鸢尾'], dtype='

feature_names键对应的值是一个字符串列表,对每一个特征进行了说明:

print("输出Feature names:\n{}".format(iris_dataset['feature_names']))

输出Feature names:
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
#特征名称:['花萼长度(cm)', '花萼宽度(cm)', '花瓣长度(cm)', '花瓣宽度(cm)']

数据包含在targetdata字段中。data里面是花萼长度、花萼宽度、花瓣长度、花瓣宽度的测量是,格式为Numpy数组:

print("输出Type of data:{}".format(type(iris_dataset['data'])))

输出Type of data:

data数组的每一行对应一朵花,列代表每朵花的四个测量数据:

print("输出Shape of data:{}".format(iris_dataset['data'].shape))

输出Shape of data:(150, 4)

可以看出,数组中包含150多不同的花的测量数据。前面说过,机器学习中的个体叫作样本(sample),其属性叫作特征(feature)。data数组的形状(Shape)是样本数乘以特征数(150 * 4)。这是scikit-learn中的约定,你的数据形状应始终遵循这个约定。

我们看下前5个样本的特征数据:

print("输出前5个数据:\n{}".format(iris_dataset['data'][:5]))

输出前5个数据:
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]

从数据中可以看出,前5朵花的花瓣宽度都是0.2cm,第一朵花的花萼最长,是5.1cm。


target数组包含的是测量过的每朵花的品种,也是一个Numpy数组:

print("输出Type of target:{}".format(type(iris_dataset['target'])))

输出Type of target:

target是一维数组,每朵花对应其中一个数据:

print("输出Shape of target:{}".format(iris_dataset['target'].shape))

输出Shape of target:(150,)

品种被转换成从0到2的整数

print("输出Targt:\n{}".format(iris_dataset['target']))

输出Targt:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
# 上述数字的代表含义由iris['target_names']数组给出:
# 0代表setosa,1代表versicolor,2代表virginica。

1.7.2 衡量模型是否成功:训练数据与测试数据

数据应当分为两个部分

一部分数据用于构建机器学习模型,叫作训练数据(training data)训练集(training set)

其余的数据用来评估模型性能,叫做测试数据(test data)测试集(test set)留出集(hold-out set)

scikit-learn中的train_test_split函数可以打乱数据集并进行拆分

这个函数将75%的数据作为训练集,25%的数据作为测试集。(比例可以随意分配,但75:25较为常用)

scikit-learn中,数据(本例中数据是花的测量数据(花瓣、花萼的长和宽))通常用大写X表示,

而标签(本例中数据是花的种类['setosa' 'versicolor' 'virginica'])用小写y表示

这是收到数学标准公式的“y=f(X)”的启发,其中x是函数的输入,y是函数的输出。

用大写X是因为数据是一个二维数组(矩阵),

用小写y是因为目标是一个一位数组(向量),这也是数学中的约定

对数据调用train_test_split函数,并对输出结果采用下面这种命名方法:

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],random_state=0)

在数据进行拆分前,train_test_split函数利用为随机数生成器见数据集打乱,确保测试集中包含所有类别的数据。

为了确保多次运行同一函数能够得到相同的输出,我们利用random_state参数指定了随机数生成器的种子,

这样函数输出是固定不变的,所以这行代码的输出始终相同。

train_test_split函数的输出为X_train,X_test,y_train,y_test,他们都是Numpy数组

print("输出X_train shape:{}".format(X_train.shape))
print("输出y_train shape:{}".format(y_train.shape))
print("输出X_test shape:{}".format(X_test.shape))
print("输出y_test shape:{}".format(y_test.shape))


输出X_train shape:(112, 4)
输出y_train shape:(112,)
输出X_test shape:(38, 4)
输出y_test shape:(38,)

1.7.3 要事第一:观察数据

在构建机器学习模型之前,通常最好检查一下数据,看看如果不用机器学习能不能轻松完成任务,或者需要的信息有没有包含在数据中。

检查数据也是发现异常值和特殊值的好方法。

检查数据最佳方法之一就是将其可视化。

一种可视化方法是绘制散点图(scatter plot)。

数据散点图将一个特征作为x轴,另一个特征作为y轴,将每一个数据点绘制为图上的一个点。

不幸的是,计算机屏幕只有两个维度,所以我们一次只能绘制两个特征(也可能是3个)。

用这种方法很难对多于3个特征的数据集作图。

解决这个问题的一种方法是绘制散点图矩阵(pair plot),从而可以两两查看所有的特征。

下图是训练集中特征的散点图矩阵。数据点的颜色与鸢尾花的品种对应。

为了绘制这张图,我们先将Numpy数组转换成pandas DateFrame。

pandas有一个绘制三点图矩阵的函数,叫做“scatter_matrix”

矩阵的对教师每个特征的直线图

由于书中采用的pd.scatter_matrix()似乎已停止更新,故此采用Jupyter Notebook推荐的pd.plotting.scatter_matrix进行绘图

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
#Anaconda 3并未默认安装mglearn,需要打开anaconda prompt输入pip install mglearn进行安装
#Python中添加mglearn库的方法:
#(1)开始——Anaconda——打开Anaconda Prompt
#(2)输入pip install mglearn(自动安装)
#(3)输入conda list,检查有无mglearn,有则成功

# 利用X_train中的数据创建DataFrame
# 利用iris_dataset.feature_names中的字符串对数据进行标记
iris_dataframe=pd.DataFrame(X_train,columns=iris_dataset.feature_names)
# 利用DataFrame创建散点图矩阵,按y_train着色
grr=pd.plotting.scatter_matrix(iris_dataframe,c=y_train,figsize=(15,15),marker='o',hist_kwds={'bins':20},s=60,alpha=0.8,cmap=mglearn.cm3)
#由于书中采用的pd.scatter_matrix()似乎已停止更新,故此采用Jupyter Notebook推荐的#pd.plotting.scatter_matrix进行绘图

plt.show()
#pycharm要用plt.show()显示图片

输出的散点图矩阵:petal length(花瓣长度)

Python机器学习基础教程——1.7第一个应用:鸢尾花分类——学习笔记_第1张图片

介绍一下scatter_matrix()各参数的含义

pandas.plotting.scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, diagonal='hist', marker='.', density_kwds=None, hist_kwds=None, range_padding=0.05, **kwds)

  • frame : 所要展示的pandas的DataFrame对象
  • alpha : 透明度,一般取(0, 1]
  • figsize : 以英寸为单位的图像尺寸,以(width, height)的形式设置 
  • ax : 一般为none
  • grid : 布尔型,控制网格的显示
  • diagonal : 须在{'hist', 'kde'}中选取一个作为参数,'hist'表示直方图,'kde'表示核密度估计
  • marker : 散点标记的类型,可选'.'或 ','或'o',默认为'.'
  • hist_kwds : 与hist相关的可变参数
  • density_kwds : 与kde相关的可变参数
  • range_padding : 图像在x轴、y轴附近的留白,默认为0.05
  • kwds : 其他可变参数
  • 还有一些代码中用到的可变参数:
  • c : 将相同的值划分为相同的颜色
  • cmap : 配色方案,代码中采用了mglearn中的方案
  • s : 散点标记的大小

从上图可以看出,利用花瓣(petal)和花萼(sepal)的测量数据基本可以将三个类别区分开。

这说明机器学习模型很可能可以学会区分它们。

1.7.4 构建第一个模型:k近邻算法

采用算法:k近邻算法

k近邻算法:要对一个新的数据点作出预测,k近邻算法会在数据集中寻找与这个点最近的数据点,然后将找到的数据点的标签值(目标值)赋给这个新的数据点。

k近邻算法中k的含义是,我们可以考虑训练集中与新数据点最近的任意k个邻居(比如说,距离最近的3个或5个邻居),而不是只考虑最近的那一个。然后,我们可以用这些邻居中数量做多的类别做出预测。

k近邻算法在sklearn的neighbors模块中的KNeighboursClassifier类中实现。KNeighboursClassifier最重要的参数就是k,k指的是考虑训练集中与新数据点最近的任意k个邻居,这里我们设为1

from sklearn.neighbors import KNeighborsClassifier
knn=KNeighborsClassifier(n_neighbors=1)

knn对象对算法进行了封装,既包括用训练数据构建模型的算法,也包括对新数据点进行预测的算法。它还包括算法从训练数据中提取的信息。对于KNeighborsClassifier来说,里面只保存了训练集。

想要基于训练集来构建模型,需要调用knn对象的fit()方法,输入参数为X_train和y_train,二者都是Numpy数组,前者包含训练数据,后者包含相应的训练标签。

knn.fit(X_train,y_train)
print("输出knn:\n{}".format(knn))


输出knn:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=1, p=2,
                     weights='uniform')

1.7.5 做出预测

现在我们可以用这个模型对新数据进行预测了,我们可能并不知道,这些新数据的正确标签。

想像一下,我们在野外发现了一朵鸢尾花,花萼长5cm宽2.9cm,花瓣长1cm宽0.2cm。这朵花应该属于哪个品种呢?

我们可以将这些数据放在一个Numpy数组里,再次计算形状,数组形状为:样本数1*特征数4

X_new=np.array([[5,2.9,1,0.2]])
print("输出X_new.shape:{}".format(X_new.shape))



输出X_new.shape:(1, 4)

注意,我们将这朵花的测量数据转换为二维Numpy数组的一行,这是因为scikit-learn的输入数据必须是二维数组。

prediction=knn.predict(X_new)
print("输出Prediction:{}".format(prediction))
print("输出Predicted target name:{}".format(iris_dataset['target_names'][prediction]))


输出Prediction:[0]
输出Predicted target name:['setosa']
#根据我们模型的预测,野外这朵鸢尾花属于类别0,也就是说他属于setosa(山鸢尾花)

1.7.6 评估模型

我们可以对测试数据中的每朵鸢尾花进行预测,并将预测结果与表情(已知的品种)进行对比。

我们可以通过计算精度(accuracy)来衡量模型的优劣,精度就是品种预测正确的花所占的比例:

我们可以使用knn对象的score方法来计算测试集的精度:

print("输出Test set sore:{:.2f}".format(knn.score(X_test,y_test)))


输出Test set sore:0.97

对于这个模型来说,测试集的精度约为0.97,也即是说,对于测试集中的鸢尾花,我们的预测有97%是正确的。根据一些数据假设,对于新的鸢尾花,可以认为我们的模型预测结果有97%都是正确的。对于我们的植物学爱好者应用程序来说,高精度意味着模型足够可信,可以使用。

1.8 小结与展望

1.鸢尾花的分类是一个监督学习问题,它有三个品种,因此又是一个三分类问题。
2.我们将数据集分成训练集(training set)和测试集(test set),前者用于构建模型,后者用于评估模型对前所未见的新数据的泛化能力。
3.我们选择了k近邻分类算法,根据新数据点在训练集中距离最近的邻居进行预测。

核心步骤是:数据集拆分→选取模型→训练模型→评估模型

核心代码:这段代码包含了应用scikit-learn中任何机器学习算法的核心代码

fit()、predict()、score()方法是scikit-learn监督学习模型中最常用的接口

X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
 
knn = KNeighborsClassifier(n_neighbors=1)
 
knn.fit(X_train, y_train)
 
print("Test set score: {:.2f}".format(knn.score(X_test, y_test)))

完整代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
 
iris_dataset = load_iris() #鸢尾花数据集
 
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
 #数据拆分,最佳比例是数据集:测试集 = 3:1
 
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', 
	hist_kwds={"bins": 20}, s=60, alpha=.8, cmap=mglearn.cm3)  #展示散点图矩阵

#plt.show()
#pycharm要用plt.show()显示图片
 
knn = KNeighborsClassifier(n_neighbors=1) #knn对算法进行了封装,包含了模型构建算法与预测算法
 
knn.fit(X_train, y_train) #构建模型
 
X_new = np.array([[5, 2.9, 1, 0.2]])
prediction = knn.predict(X_new)
 
print("Test set score: {:.2f}".format(knn.score(X_test, y_test)))

 

你可能感兴趣的:(鸢尾花分类模型)