k-近邻算法(KNN):鸢尾花种类预测

文章目录

  • 一、原理
  • 二、鸢尾花种类预测
    • 1、获取数据
    • 2、数据集划分
    • 3、特征工程(标准化)
    • 4、KNN预估器流程
    • 5、模型评估
    • 6、完整代码
  • 三、K-近邻算法总结
  • 四、模型选择与调优

一、原理

如果一个样本在特征空间中 k个最相似(即特征空间中最邻近) 的样本中大多数属于某个类别,则该样本也属于这个类别

二、鸢尾花种类预测

1、获取数据

导sklearn原有的数据集

from sklearn.datasets import load_iris
   # 1、获取数据
    iris = load_iris()

2、数据集划分

# 2、划分数据集
from sklearn.model_selection import train_test_split
# 划分数据集(训练集、测试集 x:特征值 y:目标值)
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)

3、特征工程(标准化)

# 3、特征工程 标准化
transfer = StandardScaler()
# fit_transform 计算+转换;transform 只计算
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

4、KNN预估器流程

 # 4、KNN算法预估器
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)

5、模型评估

# 5、模型评估
# 方法一:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值", y_test == y_predict)
# 方法二:计算准确率
score = estimator.score(x_test, y_test)
print("准确率:\n", score)

6、完整代码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier


def knn_irs():
    """
    knn算法对鸢尾花分类
    :return:
    """
    # 1、获取数据
    iris = load_iris()
    # 2、划分数据集(训练集、测试集 x:特征值 y:目标值)
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=28)
    # 3、特征工程 标准化
    transfer = StandardScaler()
    # fit_transform 计算+转换;transform 只计算
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    # 4、KNN算法预估器
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)
    # 5、模型评估
    # 方法一:直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("直接比对真实值和预测值", y_test == y_predict)
    # 方法二:计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率:\n", score)
    return None


if __name__ == '__main__':
    knn_irs()

三、K-近邻算法总结

优点:简单、易于理解、易于实现、无需训练
缺点:
1、必须指定K值,K值选择不当,则分类精度不能保障
2、懒惰算法,计算量大,开销大

四、模型选择与调优

交叉验证:将数据分为训练和验证集
例:将数据分成4份,其中一份作为验证集,然后经过4次测试,每次都更换不同的验证集。即得到四组模型的结果,取平均值作为最终结果。

使用交叉验证改进鸢尾花预测

  # 加入网格搜索和交叉验证
    param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
    estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
    estimator.fit(x_train, y_train)
    # 5、模型评估
    # 方法一:直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("直接比对真实值和预测值", y_test == y_predict)
    # 方法二:计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率:\n", score)
    print("最佳参数:\n", estimator.best_params_)
    print("最佳结果:\n", estimator.best_score_)
    print("最佳估计器:\n", estimator.best_estimator_)
    print("交叉验证结果:\n", estimator.cv_results_)
y_predict:
 [0 1 1 0 2 1 2 1 1 0 2 0 1 1 2 0 2 2 2 1 0 0 1 2 1 0 2 2 0 1 0 2 1 0 2 1 2
 1]
直接比对真实值和预测值 [ True False  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True False
  True  True  True  True  True  True  True  True  True  True False  True
  True  True]
准确率:
 0.9210526315789473
最佳参数:
 {'n_neighbors': 3}
最佳结果:
 0.9553030303030303
最佳估计器:
 KNeighborsClassifier(n_neighbors=3)
交叉验证结果:
 {'mean_fit_time': array([0.00030003, 0.00030031, 0.00040004, 0.00030005, 0.00030005,
       0.00040019]), 'std_fit_time': array([0.0004583 , 0.00045873, 0.00048995, 0.00045833, 0.00045833,
       0.00049013]), 'mean_score_time': array([0.00060017, 0.00089998, 0.0006    , 0.00080037, 0.00069995,
       0.00070009]), 'std_score_time': array([0.00049004, 0.00029999, 0.0004899 , 0.00040019, 0.00045822,
       0.00045832]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7, 9, 11],
             mask=[False, False, False, False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 9}, {'n_neighbors': 11}], 'split0_test_score': array([0.83333333, 0.91666667, 0.91666667, 0.91666667, 0.91666667,
       0.91666667]), 'split1_test_score': array([1., 1., 1., 1., 1., 1.]), 'split2_test_score': array([0.90909091, 0.90909091, 0.90909091, 0.90909091, 0.90909091,
       0.90909091]), 'split3_test_score': array([1.        , 1.        , 0.90909091, 0.90909091, 0.90909091,
       0.90909091]), 'split4_test_score': array([0.72727273, 0.72727273, 0.72727273, 0.72727273, 0.72727273,
       0.72727273]), 'split5_test_score': array([1., 1., 1., 1., 1., 1.]), 'split6_test_score': array([1.        , 1.        , 0.81818182, 0.81818182, 1.        ,
       1.        ]), 'split7_test_score': array([1.        , 1.        , 1.        , 0.90909091, 1.        ,
       1.        ]), 'split8_test_score': array([0.90909091, 1.        , 1.        , 1.        , 1.        ,
       1.        ]), 'split9_test_score': array([1., 1., 1., 1., 1., 1.]), 'mean_test_score': array([0.93787879, 0.95530303, 0.9280303 , 0.91893939, 0.94621212,
       0.94621212]), 'std_test_score': array([0.0894966 , 0.08343314, 0.08894662, 0.08571326, 0.08301938,
       0.08301938]), 'rank_test_score': array([4, 1, 5, 6, 2, 2])}

Process finished with exit code 0

你可能感兴趣的:(人工智能+大数据,近邻算法,python,机器学习)