利用 KNN 算法实现水果识别

# -*- coding: utf-8 -*-

"""
    案例:水果识别
    任务:通过knn算法对水果进行识别
"""


import os

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 1. Load the data
fruits_df = pd.read_table('./data/fruit_data_with_colors.txt')
# DataFrame.info() writes its summary to stdout itself and returns None,
# so it is called directly; wrapping it in print() would emit a stray "None".
fruits_df.info()
print()
print(fruits_df.describe())
print()
print(fruits_df.head())

# 2. Build a lookup from the numeric fruit label to the fruit name
fruit_name_dict = dict(zip(fruits_df['fruit_label'], fruits_df['fruit_name']))

# 3. Split the data set
# Features are the four numeric columns; the target is the numeric label.
X = fruits_df[['mass', 'width', 'height', 'color_score']]
y = fruits_df['fruit_label']
# Hold out a quarter of the samples for testing; random_state pins the shuffle.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/4, random_state=0)
print("\n数据集样本数:{},训练集样本数:{},测试集样本数:{}".format(X.shape[0], X_train.shape[0], X_test.shape[0]))

# 4./5. Choose the model (k-nearest-neighbours with k=5) and train it on the
# training split; fit() hands back the fitted estimator, so it can be chained.
knn = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)

# 6. Evaluate the model on the held-out test split
y_pred = knn.predict(X_test)
acc = accuracy_score(y_true=y_test, y_pred=y_pred)
print("\nk为5时准确率为:", acc)

# 7. Examine how the choice of k affects accuracy on the test split
k_range = range(1, 20)
acc_scores = []

for k in k_range:
    # Train a fresh classifier for every candidate k and record its test score.
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    acc_scores.append(knn.score(X_test, y_test))

plt.figure()
plt.xlabel('k')
plt.ylabel('accuracy')
plt.plot(k_range, acc_scores, marker='o')
# Evenly spaced ticks that span the evaluated k values (1..19).
plt.xticks([0, 5, 10, 15, 20])
# savefig() does not create missing directories, so make sure the target
# folder exists before writing the figure.
os.makedirs('./output', exist_ok=True)
plt.savefig('./output/k_validation.png')
plt.show()

 

你可能感兴趣的:(利用knn算法实现水果识别)