代码可在Github上下载:代码下载
k近邻可以算是机器学习中易于理解、实现的一个算法了,《机器学习实战》的第一章便是以它作为介绍来入门。而k近邻的算法可以简述为通过遍历数据集的每个样本进行距离测量,并找出距离最小的k个点。但是这样一来一旦样本数目庞大的时候,就容易造成大量的计算。
所以需要将数据用树形结构存储,以便快速检索,这也就是本文要阐述的kd树。
分为两部分,一个是kd树建立,一个是kd树的搜索。
# --*-- coding:utf-8 --*--
import numpy as np
先定义一下字符集还有包。
首先我们先实现一个结点类,用来表示kd。
class Node:
def __init__(self, data, lchild = None, rchild = None):
self.data = data
self.lchild = lchild
self.rchild = rchild
一个结点包含着结点域,左孩子,右孩子。(如果不熟二叉树的话建议先看一些数据结构二叉树的相关知识,以及先序遍历,中序遍历还有后序遍历的相关代码)
二叉树相关代码(C语言实现)
然后是创建kd树的代码,主要根据P41,算法3.2来实现的。
def create(self, dataSet, depth): #创建kd树,返回根结点
if (len(dataSet) > 0):
m, n = np.shape(dataSet) #求出样本行,列
midIndex = m / 2 #中间数的索引位置
axis = depth % n #判断以哪个轴划分数据,对应书中算法3.2(2)公式j()
sortedDataSet = self.sort(dataSet, axis) #进行排序
node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
# print sortedDataSet[midIndex]
leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2个副本
rightDataSet = sortedDataSet[midIndex+1 :]
print leftDataSet
print rightDataSet
node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
node.rchild = self.create(rightDataSet, depth+1)
return node
else:
return None
以上的代码通过看注释应该可以了解一二,其中需要按轴j(mod k)+1,也就是【depth(深度) mod n(特征数)+1】为轴划分中位数,然后决定插入数据到左结点,右结点。然后注意一下为什么上面的按轴划分的公式是【depth(深度) mod n(特征数)】,这是因为python的数组下标是从0开始的。
def sort(self, dataSet, axis): #采用冒泡排序,利用aixs作为轴进行划分
sortDataSet = dataSet[:] #由于不能破坏原样本,此处建立一个副本
m, n = np.shape(sortDataSet)
for i in range(m):
for j in range(0, m - i - 1):
if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
temp = sortDataSet[j]
sortDataSet[j] = sortDataSet[j+1]
sortDataSet[j+1] = temp
print sortDataSet
return sortDataSet
创建树的时候为了找中位数,需要按轴(某一维度)排序,找出中间那个数。这里我用了冒泡排序。
def preOrder(self, node):
if node != None:
print "tttt->%s" % node.data
self.preOrder(node.lchild)
self.preOrder(node.rchild)
当然我选择了先序遍历来简单检查下树的创建有没有问题。(看下这棵树能否正常遍历,这步可忽略)
def search(self, tree, x): #搜索
self.nearestPoint = None #保存最近的点
self.nearestValue = 0 #保存最近的值
def travel(node, depth = 0): #递归搜索
if node != None: #递归终止条件
n = len(x) #特征数
axis = depth % n #计算轴
if x[axis] < node.data[axis]: #如果数据小于结点,则往左结点找
travel(node.lchild, depth+1)
else:
travel(node.rchild, depth+1)
#以下是递归完毕,对应算法3.3(3)
distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断
if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
elif (self.nearestValue > distNodeAndX):
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
if x[axis] < node.data[axis]:
travel(node.rchild, depth+1)
else:
travel(node.lchild, depth + 1)
travel(tree)
return self.nearestPoint
def dist(self, x1, x2): #欧式距离的计算
return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
搜索树的时候比较麻烦,首先先说下原理吧。
(1) 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止;
(2) 以此叶结点为“当前最近点”;
(3) 递归的向上回退,在每个结点进行以下操作:
(a) 如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
(b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。
(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为x的最近邻点。
注意了,先按步骤找到叶结点,然后回朔的时候要做两件事,(a)是更新最新点,(b)是检查是否需要检查父结节点的另外一个结点的区域。
if x[axis] < node.data[axis]: #如果数据小于结点,则往左结点找
travel(node.lchild, depth+1)
else:
travel(node.rchild, depth+1)
这段是类似于二叉查找树的过程,直至查找到叶子节点。
#以下是递归完毕后,往父结点方向回朔,对应算法3.3(3)
distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断
if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
elif (self.nearestValue > distNodeAndX):
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
if x[axis] < node.data[axis]:
travel(node.rchild, depth+1)
else:
travel(node.lchild, depth + 1)
这段代码,就是P43算法3.3(3)中的内容。
(a)容易实现,但是(b)的原理是判断目标点和最近的一个点的距离为半径画一个圆(就如书本P44图3.5,目标点S和当前最近点D形成了一个圆),是否跟父结点按轴分的那条线(也就是圆内的那条直线)有交集。
说白了,就是公式:|目标值(按轴读值) - 父节点(按轴读值)| < 最近的值(圆的半径),这里按轴读取就是P44图3.5中的x的y轴的值,然后减去相交的那条直线y轴的值,看是否小于半径。
注意:评论里有说这里的node.data不知道是指示哪个结点。这里要说明的是,这个node并不是父节点,而是当前结点。这里如果你对数据结构的二叉树不太熟的话,是不太容易get到这个点的。我只能稍微说下。
“这里应该了解下二叉查找树的过程”
如果找到了的话,把另一结点重新递归一次就好了。对应以下代码:
travel(node.rchild, depth+1)
最后在github贴出全部代码(如果方便的话麻烦给个赞吧,您的支持就是我前进的动力),然后来运行一下代码(这段代码在python3.5下成功运行)。
KNN(KDtree)代码下载
结果输出(5,4)