k-近邻算法 ---sklearn

k-近邻算法也称为knn算法, 可以解决分类问题, 也可以解决回归问题.

1. 算法原理

k-近邻算法的核心思想是未标记样本的类别, 由距离其最近的 \(k\) 个邻居投票来决定.

假设, 我们有一个已经标记的数据集, 即已经知道了数据集中每个样本所属的类别. 此时, 有一个未标记的数据样本, 我们的任务是预测出这个数据样本所属的类别. k-近邻算法的原理是, 计算待标记的数据样本和数据集中每个样本的距离, 取距离最近的 \(k\) 个样本. 待标记的数据样本所属的类别, 就由这 \(k\) 个距离最近的样本投票产生.

假设 X_test 为待标记的数据样本, X_train 为已标记的数据集, 算法原理的伪代码如下:

1.遍历 X_train 中的所有样本, 计算每个样本与 X_test 的距离, 并把距离保存在 Distance 数组中.

2.对 Distance 数组进行排序, 取距离最近的 \(k\) 个点, 记为 X_knn.

3.在 X_knn 中统计每个类别的个数, 即 class0 在 X_knn 中有几个样本, class1 在 X_knn 中有几个样本等. 

4.待标记样本的类别, 就是在 X_knn 中样本个数最多的那个类别.

1.1 算法优缺点

优点: 准确性高, 对异常值和噪声有较高的容忍度.

缺点: 计算量较大, 对内存的需求也较大. 从算法原理可以看出来, 每次对一个未标记的样本进行分类时,  都需要全部计算一遍距离.

1.2 算法参数

其算法参数是 \(k\) ,参数选择需要根据数据来决定.  \(k\) 值越大, 模型的偏差越大, 对噪声数据越不敏感, 当 \(k\) 值很大时, 可能造成模型欠拟合;  \(k\) 值越小, 模型的方差就会越大, 当 \(k\) 值太小, 就会造成模型过拟合.

1.3 算法的变种

k-近邻算法有一些变种, 其中之一就是可以增加邻居的权重. 默认情况下, 在计算距离时, 都是使用相同权重. 实际上, 我们可以针对不同的邻居指定不同的距离权重, 如距离越近权重越高. 

另一个变种是, 使用一定半径内的点取代距离最近的 \(k\) 个点. 在 scikit-learn 里, RadiusNeighborsClassifier 类实现了这个算法的变种. 当数据采用不均匀时, 该算法变种可以取得更好的性能.

2. 示例: 使用k-近邻算法进行分类

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets._samples_generator import make_blobs
from sklearn.neighbors import KNeighborsClassifier

centers = [[-2, 2], [2, 2], [0, 4]]
X, y = make_blobs(n_samples=60, centers=centers,
                 random_state=1, cluster_std=0.6)  # 产生60个样本

k = 5
clf =KNeighborsClassifier(n_neighbors=k)
clf.fit(X, y)

X_sample = [[1, 2]]
y_sample = clf.predict(X_sample)
print(y_sample)
neighbors = clf.kneighbors(X_sample, return_distance=False)
# 画出数据
plt.figure(figsize=(16, 10), dpi=144)
c = np.array(centers)
plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='cool')  # 画出样本
plt.scatter(c[:, 0], c[:, 1], s=100, marker='^', c='orange')  # 画出中心点
plt.scatter(X_sample[0][0], X_sample[0][1], marker='x', c=y_sample, s=100, cmap='hot')  # 代预测的点
for i in neighbors[0]:
    plt.plot([X[i][0], X_sample[0][0]], [X[i][1], X_sample[0][1]],
             'k--', linewidth=0.6)  # 预测点与距离最近的5个样本的连线
plt.show()

k-近邻算法 ---sklearn_第1张图片 预测的类别索引为[1], 根据上面设置的中心点, 预测点属于中心点为[2,2]的类别.


原作者: 黄勇昌

代码部分稍作修改, X_sample 一维改二维.

你可能感兴趣的:(机器学习,sklearn,机器学习,python)