python进行KNN算法分析实战(鸢尾花数据集)

KNN算法分析实战(鸢尾花数据集)

目录

KNN算法分析实战(鸢尾花数据集)

 

代码效果图

一、导入需要的包

二、

1.导入数据

 

2.建立训练集和测试集

3.设置K值

4. 十重交叉验证K值

5.模型拟合 

6.数据可视化输出


代码效果图


,废话不多说,先看看代码实验结果

python进行KNN算法分析实战(鸢尾花数据集)_第1张图片

 

python进行KNN算法分析实战(鸢尾花数据集)_第2张图片

python进行KNN算法分析实战(鸢尾花数据集)_第3张图片

python进行KNN算法分析实战(鸢尾花数据集)_第4张图片


提示:以下是本篇文章正文内容,下面案例可供参考

一、导入需要的包

要是报错的话可以在pycharm安装包,要是不行就在命令窗口输入pip install +包名

import matplotlib.pyplot as plt
from sklearn import neighbors
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import model_selection
from sklearn import metrics

二、

1.导入数据

导入数据并查看前5行代码

df1 = pd.read_csv(r'D:\python\iris.csv')
print(df1.head())#输出前五行
predictors = df1.columns[:-1]

 

python进行KNN算法分析实战(鸢尾花数据集)_第5张图片

2.建立训练集和测试集

代码如下:

x_train,x_test,y_train,y_test=model_selection.train_test_split(
    df1[predictors],df1.Species,
    test_size=0.5,
    random_state = 1234
)
print(np.ceil(np.log2(df1.shape[0])))

3.设置K值

#设置待测试的不同K值
K = np.arange(1,np.ceil(np.log2(df1.shape[0])))
print(np.arange(1,np.ceil(np.log2(df1.shape[0]))))
#设置空列表,用于储存平均准确率
accuracy = []

4. 十重交叉验证K值

使用十重交叉验证K值,并做出最适合K值的折线图

#使用十重交叉验证的方法
for k in K:
    cv_result = model_selection.cross_val_score\
        (neighbors.KNeighborsClassifier(n_neighbors=int(k),
                                        weights='distance'),
         x_train, y_train, cv=10, scoring='accuracy')
    accuracy.append(cv_result.mean())

#从K个平均准确率中挑选出最大值做对应的目标
arg_max = np.array(accuracy).argmax()
#中文负号正常显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
#绘制不同k值与准确率之间的折线图
plt.plot(K,accuracy)
plt.scatter(K,accuracy)
plt.text(K[arg_max],accuracy[arg_max],'最佳K值为%s'%int(K[arg_max]))
plt.show()

 python进行KNN算法分析实战(鸢尾花数据集)_第6张图片

5.模型拟合 

代入K值,进行模型拟合

#重新构建模型,并将最佳邻近数个数设置为7
knn_class = neighbors.KNeighborsClassifier(n_neighbors=7,weights='distance')
#模型拟合
knn_class.fit(x_train,y_train)
#模型在测试集上的预测
predict = knn_class.predict(x_test)

6.数据可视化输出

#构建混淆矩阵
cm = pd.crosstab(predict,y_test)
print(f'鸢尾花种类混淆矩阵\n{cm}')
#热力图输出
cm = pd.DataFrame(cm,columns=['setosa','versicolor','virginica'],
                  index=['setosa','versicolor','virginica'])
sns.heatmap(cm,annot=True,cmap='GnBu')
plt.xlabel('Real Lable')
plt.ylabel('Predict Lable')
plt.title('鸢尾花种类热力图')
plt.show()
#显示各类预测准确率
b = metrics.classification_report(y_test,predict)
print(f'显示各类预测准确率\n{b}')

 

 python进行KNN算法分析实战(鸢尾花数据集)_第7张图片python进行KNN算法分析实战(鸢尾花数据集)_第8张图片

 

你可能感兴趣的:(大数据,python,数据分析,机器学习,数据挖掘)