【鸢尾花的k值调优】

鸢尾花数据集介绍

鸢尾花数据集包含了三个类别的鸢尾花样本:Setosa、Versicolor和Virginica。每个样本有四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。我们的目标是通过这些特征来预测鸢尾花的类别。

K近邻算法简介

K近邻算法是一种简单而有效的分类算法。它的基本思想是:对于一个未知样本,通过计算其与训练集中所有样本的距离,选取距离最近的k个样本,然后根据这k个样本的类别来预测未知样本的类别。

超参数搜索

在K近邻算法中,k值是一个重要的超参数。不同的k值可能会导致模型性能的显著变化。因此,我们需要通过超参数搜索来找到最优的k值。这里我们将使用网格搜索(Grid Search)来进行超参数搜索。

网格搜索

网格搜索是一种通过遍历给定超参数的所有可能组合来确定最优超参数的方法。为了进行网格搜索,我们首先要定义搜索的k值范围,然后在这个范围内尝试所有可能的k值,然后评估每个k值对应的模型性能,并选择性能最优的k值。

示例代码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 定义要搜索的k值范围
param_grid = {'n_neighbors': [1, 3, 5, 7, 9, 11, 13, 15]}

# 创建K近邻分类器
knn = KNeighborsClassifier()

# 初始化网格搜索对象
grid_search = GridSearchCV(knn, param_grid, cv=5)

# 在训练集上进行网格搜索
grid_search.fit(X_train, y_train)

# 方式1直接比较预测值和真实值
y_pred = grid_search.predict(X_test)
print(y_pred == y_test)
print("准确率:", sum(y_pred == y_test) / len(y_test))

# 方式2计算在测试集上的准确率
score = grid_search.score(X_test, y_test)
print("准确率:", score)

# 输出最优的k值和对应的准确率
print("最优的k值:", grid_search.best_params_['n_neighbors'])
print("最优的准确率:", grid_search.best_score_)
print("最优的模型:", grid_search.best_estimator_)

在这个示例代码中,我们使用了sklearn库中的GridSearchCV类进行网格搜索。我们指定了要搜索的k值范围为[1, 3, 5, 7, 9, 11, 13, 15],然后网格搜索会尝试这些k值,并返回最优的k值和对应的准确率。

你可能感兴趣的:(数学建模,机器学习,算法,人工智能)