K-近邻算法

k-近邻算法原理


k-近邻算法采用测量不同特征值之间的距离方法进行分类

  • 优点:精度高,对异常值不敏感、无数据输入假定。
  • 缺点:时间复杂度高、空间复杂度高。
  • 使用数据范围:数值型和标称型。

1.工作原理


存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据 与所属分类的对应关系。输人没有标签的新数据后,将新数据的每个特征与样本集中数据对应的 特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们 只选择样本数据集中前K个最相似的数据,这就是K-近邻算法中K的出处,通常K是不大于20的整数。 最后 ,选择K个最相似数据中出现次数最多的分类,作为新数据的分类。

2.在scikit-learn库中使用k-近邻算法


  • 分类问题:from sklearn.neighbors import KNeighborsClassifier
  • 回归问题:from sklearn.neighbors import KNeighborsRegressor

实例


  • 用于分类
    使用knn算法,对鸢尾花数据进行分类
    1.导包鸢尾花数据:
    from sklearn.datasets import load_iris
    2.获取训练样本
    iris = load_iris()
    data = iris.data
    target = iris.target




    3.绘制出其中两个特征的散点图
    plt.scatter(data[:,0], data[:,1], c=target, cmap='rainbow')



    3.定义KNN分类器
    knn = KNeighborsClassifier()
    • 第一步训练数据
      knn.fit(data[:,0:1], target)
      从训练数据中分割出预测数据
      from sklearn.model_selection import train_test_split
      X_train, X_test, y_train, y_test = train_test_split(data[:,0:1],target, test_size=50)
      y_ = knn.predict(X_test)
      y_test
      plt.plot(np.arange(50),y_, np.arange(50), y_test)


    • 第二步预测数据:,所预测的数据,自己创造,就是上面所显示图片的背景点
      生成预测数据
      取范围
      xmin, xmax = data[:,0].min(), data[:,0].max()
      ymin, ymax = data[:,1].min(), data[:,1].max()
      生成面
      x = np.linspace(xmin, xmax, 1000)
      y = np.linspace(ymin, ymax, 1000)
      X,Y = np.meshgrid(x,y)
      X_test = np.c_[X.ravel(), Y.ravel()]

      data = data[:, 0:2]
      knn = KNeighborsClassifier()
      knn.fit(data, target)

      y_ = knn.predict(X_test)
      pcolormesh快速画图
      plt.pcolormesh(X,Y, y_.reshape((1000,1000)))
      plt.scatter(data[:,0], data[:,1], c=target, cmap='rainbow')

你可能感兴趣的:(K-近邻算法)