KNN算法

一、KNN算法简介

K最近邻(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。 K最近邻算法就是将数据集合中每一个记录进行分类的方法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。
优点:

  • 所选择的邻居都是已经正确分类的对象
  • KNN算法本身比较简单,分类器不需要使用训练集进行训练,训练时间复杂度为0。本算法分类的复杂度与训练集中数据的个数成正比。
  • 对于类域的交叉或重叠较多的待分类样本,KNN算法比其他方法跟合适。

缺点:

  • 当样本分布不平衡时,很难做到正确分类
  • 计算量较大,因为每次都要计算测试数据到全部数据的距离。

二、算法原理

KNN算法_第1张图片
如上图所示,图中的数据可以分为蓝色正方形和红色三角形两类,图中心的绿色圆点是待分类数据,下面我们通过K最近领法对绿色圆点进行分类:
1.当 k = 3 k=3 k=3时,由图中实线圆内的数据可知:绿色圆点最近领的三个邻居中,一共有一个蓝色正方形和两个红色三角形,那么就可以将绿色圆点和红色三角形判定为一类。
2.当 k = 5 k=5 k=5时,由图中虚线圆内的数据可知:绿色圆点最近领的五个邻居中,一共有三个蓝色正方形和两个红色三角形,那么就可以将绿色圆点和蓝色正方形判定为一类。
至此,我们对KNN算法已经有了大概的了解。

三、算法步骤

1.初始化数据集

初始化训练集和测试集。训练集一般为两类或者多种类别的数据;测试集一般为一个数据。

2.计算距离

计算测试数据到其他所有数据的距离,并记录下来。
常用到的距离计算公式如下:
①欧几里得距离(欧氏距离): d = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 d=\sqrt{(x_{1}-x_{2})^{2}+(y_{1}-y_{2})^{2}} d=(x1x2)2+(y1y2)2
②曼哈顿距离
③闵可夫斯基距离
④切比雪夫距离
⑤马氏距离
⑥余弦相似度
⑦皮尔逊相关系数
⑧汉明距离
⑨杰卡德相似系数
⑩编辑距离
⑪DTW 距离
⑫KL 散度

3.寻找最近邻数据

将所有距离进行升序排序,确定K值,最近的K个邻居即距离最短的K个数据。
关于K值得选择问题:

  • K 值的选择会对算法的结果产生重大影响。
  • K值较小意味着只有与测试数据较近的训练实例才会对预测结果起作用,容易发生过拟合。
  • 如果 K 值较大,优点是可以减少学习的估计误差,但缺点是学习的近似误差增大,这时与测试数据较远的训练实例也会对预测起作用,使预测发生错误。
  • 在实际应用中,K 值一般选择一个较小的数值,通常采用交叉验证的方法来选择最优的 K 值。随着训练实例数目趋向于无穷和 K=1 时,误差率不会超过贝叶斯误差率的2倍,如果K也趋向于无穷,则误差率趋向于贝叶斯误差率。(贝叶斯误差可以理解为最小误差)

三种交叉验证方法

  • Hold-Out: 随机从最初的样本中选出部分,形成交叉验证数据,而剩余的就当做训练数据。 一般来说,少于原本样本三分之一的数据被选做验证数据。常识来说,Holdout 验证并非一种交叉验证,因为数据并没有交叉使用。
  • K-foldcross-validation:K折交叉验证,初始采样分割成K个子样本,一个单独的子样本被保留作为验证模型的数据,其他K-1个样本用来训练。交叉验证重复K次,每个子样本验证一次,平均K次的结果或者使用其它结合方式,最终得到一个单一估测。这个方法的优势在于,同时重复运用随机产生的子样本进行训练和验证,每次的结果验证一次,10折交叉验证是最常用的。
  • Leave-One-Out Cross Validation:正如名称所建议, 留一验证(LOOCV)意指只使用原本样本中的一项来当做验证资料, 而剩余的则留下来当做训练资料。 这个步骤一直持续到每个样本都被当做一次验证资料。 事实上,这等同于 K-fold 交叉验证是一样的,其中K为原本样本个数。

4.决策分类

明确K个邻居中所有数据类别的个数,将测试数据划分给个数最多的那一类。即由输入实例的 K 个最临近的训练实例中的多数类决定输入实例的类别。
最常用的两种决策规则:

  • 多数表决法:多数表决法和我们日常生活中的投票表决是一样的,少数服从多数,是最常用的一种方法。
  • 加权表决法:有些情况下会使用到加权表决法,比如投票的时候裁判投票的权重更大,而一般人的权重较小。所以在数据之间有权重的情况下,一般采用加权表决法。

图示说明(其中K=4):
KNN算法_第2张图片

四、python代码实现

from operator import attrgetter
import matplotlib.pyplot as plt
import matplotlib
from math import sqrt


class point:
    def __init__(self, kind, dis):
        self.kind = kind
        self.dis = dis

#####初始化数据集#####
data_A = [[1,2],[3.2,4],[4,7],[5.2,3],[7,4.1]]#数据集A
data_B = [[2.2,5.5],[4.2,2],[5,5],[6.3,7]]#数据集B
test_data = [[4.5,4.5], [1, 2]]#测试集
num_A = len(data_A)
num_B = len(data_B)
num_T = len(test_data)

def getDis(p1, p2):
    return sqrt(pow(p1[0] - p2[0], 2) + pow(p1[1] - p2[1], 2))

# calc dist
def calcDist():
    lsDist = [[], []]
    ls = []
    for i in range(len(test_data)):
        for j in range(num_A):
            lsDist[i].append(point(0, getDis(data_A[j], test_data[i])))
            ls.append(getDis(data_A[j], test_data[i]))
        for j in range(num_B):
            lsDist[i].append(point(1, getDis(data_B[j], test_data[i])))
            ls.append(getDis(data_B[j], test_data[i]))

    return lsDist, ls

def judge(k, lsDist):
    num0 = 0
    num1 = 0
    for j in range(len(test_data)):
        for i in range(k):
            if(lsDist[j][i].kind == 0):
                num0 += 1
            else:
                num1 += 1
        if(num0 > num1):
            print('A类')
        else:
            print('B类')

def draw():
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']
    for i in range(num_A-1):
        plt.plot(data_A[i][0], data_A[i][1], 'r^')
    plt.plot(data_A[num_A-1][0], data_A[num_A-1][1], 'r^', label='A')
    for i in range(num_B-1):
        plt.plot(data_B[i][0], data_B[i][1], 'bo')
    plt.plot(data_B[num_B-1][0], data_A[num_B-1][1], 'bo', label='B')
    for i in range(num_T-1):
        plt.plot(test_data[i][0], test_data[i][1], 'k+')
    plt.plot(test_data[num_T-1][0], test_data[num_T-1][1], 'k+', label = '未标识')
    plt.legend()
    plt.xlim(0, 10)
    plt.ylim(0, 10)
    plt.show()

def printf(lsDist):
    for j in range(len(lsDist)):
        for i in range(len(lsDist[j])):
            print('({},{})'.format(lsDist[j][i].dis, lsDist[j][i].kind))

if __name__ == '__main__':
    lsDist, ls = calcDist()
    print('距离列表')
    printf(lsDist)
    for i in range(len(lsDist)):
        lsDist[i] = sorted(lsDist[i], key=attrgetter('dis'))
    print('排序后的距离列表')
    printf(lsDist)
    k = int(input('请输入k:'))
    judge(k, lsDist)
    draw()

图形:
KNN算法_第3张图片

你可能感兴趣的:(机器学习,算法,python)