关于KD树的构建和查找的理论可参考《统计学习方法》第三章以及这篇博文https://zhuanlan.zhihu.com/p/23966698,
import os
import csv
import numpy as np
import string
import pandas as pd
import operator
import re as re
import time
import datetime
class KD_node():
def __init__(self, vector, label, depth, dimension, l_node = None, r_node = None):
self.val = vector
self.label = label
self.depth = depth
self.dimension = dimension
self.l_node = l_node
self.r_node = r_node
def add_lnode(self, new_node):
self.l_node = new_node
def add_rnode(self, new_node):
self.r_node = new_node
def print_preorder(self):
print self.val
if self.l_node:
self.l_node.print_preorder()
if self.r_node:
self.r_node.print_preorder()
def search(self, fea_vec, L, k):
if self.val[self.dimension] < fea_vec[self.dimension]:#目标点在当前结点左边
if self.l_node != None:
L = self.l_node.search(fea_vec, L, k) #若还有左结点则搜索左结点
L = self.insertL(fea_vec, L, k) #对当前结点执行插入操作(包含了是否插入条件检测)
if len(L) < k: #若L仍然未满,则不用考虑,直接搜索右子树
L = self.r_node.search(fea_vec, L, k)
else if (self.val[self.dimension] - fea_vec[self.dimension])**2 < L[-1][1]: #判断右子树是否可能存在能插入L的结点,判断方法是求目标点到切割超平面的距离,此距离也就是切割维度上,切割点与目标点的距离,并将此距离与L内最大距离比较,若小于L内最大距离,则说明在切割超平面的另一边仍有可能存在能插入L的结点,我认为这一步,这个判断条件也是KD树提高搜索效率的关键所在
L = self.r_node.search(fea_vec, L, k)
else: #目标点在当前结点右边,跟上边的情况是对称的,就不写那么多注释了
if self.r_node != None:
L = self.r_node.search(fea_vec, L, k)
L = self.insertL(fea_vec, L, k)
if len(L) < k:
L = self.l_node.search(fea_vec, L, k)
else if (self.val[self.dimension] - fea_vec[self.dimension])**2 < L[-1][1]:
L = self.l_node.search(fea_vec, L, k)
return L
def insertL(self, fea_vec, L, k): #插入操作,检查当前结点是否可以插入L,若可以则插入,否则不做操作
distance = sum((fea_vec - self.val)**2)
if len(L) < k: #如果L还没满,则直接插入
L.append([self, distance])
L = L[L[:, 1].argsort()] #插入后排序,对Python不太熟,急于实现算法,用的效率比较低的方式,见谅。。。
else if distance < L[-1][1]: #若L已满,判断L中最大距离是否比当前结点距离更大
L[-1] = [self, distance]
L = L[L[:, 1].argsort()]
return L
def construct(data, depth):
if len(data) == 0:
return None
dimension_sum = np.shape(data)[1] - 1 #特征维度,减1是因为label也占了一列
dimension = depth % dimension_sum + 1 #切割轴维度
data = data[data[:, dimension].argsort()]
node_data = data[len(data)/2]
new_node = KD_node(node_data[1:], node_data[0], depth, dimension)
new_node.l_node = construct(data[:len(data)/2], depth + 1)
new_node.r_node = construct(data[len(data)/2 + 1:], depth + 1)
return new_node
class KD_Tree(): #与一般的二叉排序树结构差不多,区别在于KD树需要反复使用各个维度来比较以构造二叉树
# def __init__(self, data_root):
# data_file = pd.read_csv(data_root, header = None)
# self.data = np.array(data_file)[:, 2:]
# self.label = np.array(data_file)[:, 1]
def __init__(self, data):
self.data = data[:,:2]
self.label = data[:, 2]
def constructor(self):
data = np.c_[self.label, self.data] #合并特征向量和label
data = data[data[:, 1].argsort()] #按第一维排序(第0维是label)
self.root = construct(data, 0)
def print_preorder(self):
print 'preorder'
self.root.print_preorder()
def predict(self, fea_vec, k = 3):
L = [] #存储当前检测到的距离最近的K个点,初始为空,L中的元素为[KD_node, distance]
self.root(fea_vec, L, k)
label_dict = {}
for each_sample in L:
label_dict[each_sample.label] += 1
print(sorted(label_dict,key=lambda x:label_dict[x])[-1]) #打印出k个最近邻结点中个数最多的类别
代码还没有调试,思路应该没错,明天调好再更新一下。。。。