机器学习笔记1:KNN

k近邻算法

优点:

精度高、对异常值不敏感、无数据输入假定。
KNN理论简单,容易实现。

缺点:

计算复杂度高、空间复杂度高。

  • k近邻算法必须保存全部数据集,如果训练数据集的很大,必须使用大量的存储空间。此外, 由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。
  • k-近邻算法的另一个缺陷是它无法给出任何数据的基础结构信息,因此我们也无法知晓平均实例样本和典型实例样本具有什么特征。

适用数据范围:

数值型和标称型。

应用领域:

文本分类;模式识别;聚类分析;多分类领域。

工作原理

存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。

sklearn的kNN示例

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

X_train = np.array([[3.393533211, 2.331273381],
                    [3.110073483, 1.781539638],
                    [1.343808831, 3.368360954],
                    [3.582294042, 4.679179110],
                    [2.280362439, 2.866990263],
                    [7.423436942, 4.696522875],
                    [5.745051997, 3.533989803],
                    [9.172168622, 2.511101045],
                    [7.792783481, 3.424088941],
                    [7.939820817, 0.791637231]
                    ])
Y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
x = np.array([7, 3])
# 转化为一个二维数组
print (x.reshape(1, -1))

# 近邻的样本数n_neighbors
kNN_classifier = KNeighborsClassifier(n_neighbors=6) 
# 模型拟合数据
kNN_classifier.fit(X=X_train, y=Y_train) 

predict = kNN_classifier.predict(x.reshape(1, -1))
print (predict)

#整合
def KNN_classify(k, X_train, Y_train, x):
    """
    KNN分类器
    :param k: K近邻的样本数
    :param X_train: 训练数据的x
    :param Y_train: 训练数据的y
    :param x: 待预测数据x
    :return: 待预测数据的y
    """
    kNN_classifier = KNeighborsClassifier(n_neighbors=k)  # 创建类实例
    kNN_classifier.fit(X=X_train, y=Y_train)  # 模型拟合数据
    predict = kNN_classifier.predict(x)
    return predict[0]

print(KNN_classify(6, X_train, Y_train, x.reshape(1, -1)))

pyplot创建散点图

import numpy as np
import matplotlib.pyplot as plt

# 将数据可视化
plt.scatter(X_train[y_train==0,0],X_train[y_train==0,1], color='g', label = 'Tumor Size')
plt.scatter(X_train[y_train==1,0],X_train[y_train==1,1], color='r', label = 'Time')
plt.xlabel('Tumor Size')
plt.ylabel('Time')
plt.axis([0,10,0,5])
plt.show()
机器学习笔记1:KNN_第1张图片
image.png

你可能感兴趣的:(机器学习笔记1:KNN)