K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
KNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。
KNN的优点是:
1. 简单,易于理解,易于实现,无需估计参数;
2. 适合对稀有事件进行分类;
3. 特别适合于多分类问题(multi-modal,对象具有多个类别标签),kNN比SVM的表现要好。
缺点是:
当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数,而本应该属于小容量类的该样本会被误判为大容量类。
解决办法是给K个邻居加上权值,例如按与待分类样本距离的倒数加权,使较近的邻居对投票结果的影响更大。
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# Load the iris dataset bundled with scikit-learn.
iris = datasets.load_iris()
iris_X = iris.data
iris_y = iris.target
print('鸢尾花特征值:{}'.format(iris.feature_names))
print('鸢尾花类别:{}'.format(iris.target_names))

# Hold out 25% of the samples for testing; the fixed seed keeps the
# split (and therefore the printed results) reproducible.
iris_X_train, iris_X_test, iris_y_train, iris_y_test = train_test_split(
    iris_X, iris_y, train_size=0.75, random_state=5)

# Build the scikit-learn k-nearest-neighbours classifier.
#   n_neighbors=5      vote among the 5 closest training samples
#   weights='uniform'  every neighbour's vote counts equally
#                      ('distance' would weight votes by inverse distance;
#                      a callable mapping distances -> weights also works)
#   algorithm='auto'   let sklearn choose ball_tree / kd_tree / brute force
#                      from the data passed to fit(); sparse input always
#                      falls back to brute-force search
#   metric='minkowski' with p=2 is plain Euclidean distance
knn = KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
                           weights='uniform')

knn.fit(iris_X_train, iris_y_train)
iris_y_predict = knn.predict(iris_X_test)
score = knn.score(iris_X_test, iris_y_test)
# Per-class precision / recall / F1 summary of the predictions.
report = classification_report(iris_y_test, iris_y_predict)
print('iris_y_test:{}'.format(iris_y_test))
print('iris_y_predict:{}'.format(iris_y_predict))
print('评分:{}'.format(score))
print('预测报告:')
print(report)
新建 knn.py 文件,手动实现 KNN 算法:
import numpy as np
class KNN(object):
    """Brute-force k-nearest-neighbour classifier using Euclidean distance.

    predict() labels each query point with the majority class among the
    k closest training samples; ties are broken in favour of the class
    that reached the maximum count first.
    """

    def __init__(self, k=5):
        # Training data is filled in by fit(); k is the neighbourhood size.
        self.x = None
        self.y = None
        self.k = k

    def fit(self, x, y):
        """Memorise the training samples x and their labels y.

        np.asarray lets callers pass plain Python lists — the original
        stored them raw, and predict() then crashed on `self.x - x_t`.
        """
        self.x = np.asarray(x)
        self.y = np.asarray(y)

    def predict(self, x_test):
        """Return a 1-D array with the predicted label for each row of x_test."""
        predict_list = []
        for x_t in np.asarray(x_test):
            # Euclidean distance from the query point to every training sample.
            diff = self.x - x_t
            distances = np.sum(np.square(diff), axis=1) ** 0.5
            # np.argsort returns indices ordered by increasing distance,
            # e.g. argsort([3, 1, 2]) -> [1, 2, 0].
            sorted_dis_index = np.argsort(distances)
            # Guard against k exceeding the training-set size (the original
            # indexed out of bounds in that case).
            k = min(self.k, len(sorted_dis_index))
            # Count how often each class appears among the k nearest neighbours.
            class_count = {}
            for i in range(k):
                vote_label = self.y[sorted_dis_index[i]]
                class_count[vote_label] = class_count.get(vote_label, 0) + 1
            # First key (insertion order) with the maximal count wins — the
            # same tie-breaking as the original manual scan.
            predict_list.append(max(class_count, key=class_count.get))
        return np.array(predict_list)

    def score(self, y_t, y_p):
        """Return the fraction of predictions y_p that match the truth y_t.

        Returns 0.0 for empty input instead of raising ZeroDivisionError.
        """
        if len(y_t) == 0:
            return 0.0
        count = sum(1 for pred, truth in zip(y_p, y_t) if pred == truth)
        return count / len(y_t)
下面编写训练和测试脚本,调用上面实现的 knn 模块:
from sklearn import datasets
from sklearn.model_selection import train_test_split
import knn

# Fetch the iris dataset.
iris = datasets.load_iris()
iris_X = iris.data
iris_y = iris.target
print('鸢尾花特征值:{}'.format(iris.feature_names))
print('鸢尾花类别:{}'.format(iris.target_names))

# 75/25 train/test split; the fixed seed makes the run reproducible
# (and comparable with the sklearn version above).
iris_X_train, iris_X_test, iris_y_train, iris_y_test = train_test_split(
    iris_X, iris_y, train_size=0.75, random_state=5)

# Train the hand-written classifier from knn.py and evaluate it.
kn = knn.KNN()
kn.fit(iris_X_train, iris_y_train)
iris_y_predict = kn.predict(iris_X_test)
score = kn.score(iris_y_test, iris_y_predict)
print(iris_y_test)
print(iris_y_predict)
print('评分:{}'.format(score))