如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别。
注意:若K = 1,分类结果容易受到异常值的干扰
对数据进行距离计算之前应该先进行标准化处理,不然数据之间数值差距太大。
(metric = 'minkowski')
(p = 1,metric = 'minkowski')
(p = 2,metric = 'minkowski')
K值过大,容易受到样本不均衡的影响;K值过小,容易受到异常值影响;
K值的选择影响分类的结果。
sklearn.neighbor.KNeighborsClassifier(n_neighbors=5, algorithm=‘auto’)
from sklearn.datasets import load_iris #导入鸢尾花数据集模块
from sklearn.model_selection import train_test_split #导入数据集划分模块
from sklearn.preprocessing import StandardScaler #导入标准化模块
from sklearn.neighbors import KNeighborsClassifier #导入KNN算法模块
def knn_iris():
# 获取数据集
iris = load_iris()
# 划分数据集
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=6) #随机数种子
# 特征工程 - 标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train) # 对训练集进行标准化
x_test = transfer.transform(x_test)
# KNN算法模型
estimator = KNeighborsClassifier(n_neighbors=3)
estimator.fit(x_train,y_train) #模型
# 模型评估
#方法一:直接比对
y_predict = estimator.predict(x_test)
print("直接比对",y_predict == y_test)
#方法二:计算准确率
score = estimator.score(x_test,y_test)
print("准确率为:",score)
return None
knn_iris()
运行结果:
[ True True True True True True False True True True True True
True True True False True True True True True True True True
True True True True True True True True True True False True
True True]
准确率为: 0.9210526315789473
为了让模型更加准确,可以将拿到的训练集数据,再分为训练和验证集
以下图为例,将训练集数据分成4份,其中一份作为验证集。然后经过4次的测试,每次都更换不同的验证集。即得到4组模型的结果,取平均值作为最终结果,又称4折交叉验证。
通常情况下,很多参数是需要手动指定的(如K-近邻算法中的K值),这种叫超参数。
但是手动指定过程繁杂,所以需要对模型预设几种超参数组合,即使用网格搜索来对预先指定的几个超参数的取值进行遍历测试比较。
每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型
sklearn.model_selection.GridSearchCV(estimator, param_grid=None, cv=None)
对估计器的指定参数进行详尽搜索
最佳参数:estimator.best_params_
最佳结果:estimator.best_score_
最佳估计器:estimator.best_estimator_
交叉验证结果:estimator.cv_results_
from sklearn.datasets import load_iris #导入鸢尾花数据集模块
from sklearn.model_selection import train_test_split #导入数据集划分模块
from sklearn.preprocessing import StandardScaler #导入标准化模块
from sklearn.neighbors import KNeighborsClassifier #导入KNN算法模块
from sklearn.model_selection import GridSearchCV #导入模型调优模块
def knn_iris_gscv(): #使用KNN对鸢尾花进行分类,包含网格搜索和交叉验证
# 获取数据集
iris = load_iris()
# 划分数据集
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=6) #随机数种子
# 标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train) # 对训练集进行标准化
x_test = transfer.transform(x_test)
# KNN算法预估器
estimator = KNeighborsClassifier() #模型创建
# 预估器生成之后增添网格搜索与交叉验证
# 参数准备
param_dict = {'n_neighbors':[1,3,5,7]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10) #模型调优
estimator.fit(x_train,y_train) #模型调用
# 模型评估
#方法一:直接比对
y_predict = estimator.predict(x_test)
print("直接比对",y_predict == y_test)
#方法二:计算准确率
score = estimator.score(x_test,y_test)
print("准确率为:",score)
print("最佳参数:\n",estimator.best_params_)
print("最佳结果:\n", estimator.best_score_)
print("最佳估计器:\n", estimator.best_estimator_)
print("交叉验证结果:\n", estimator.cv_results_)
return None
knn_iris_gscv()