参考百度百科：http://baike.baidu.com/link?url=JLBeRUhL6WLyp8R6TAFDD8swLfazjQnOaSXBY3AydkrVQG8XpCJ8EIh4bWpB02wQxxzPrK723ulRCzSKxkFLy_
下面是我的 C++ 实现：
// kd-tree.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Scalar type of one point coordinate. A typedef (rather than a macro) so the
// name obeys normal scoping and shows up in diagnostics.
typedef double KeyType;

/**
 * k-d tree over caller-owned points.
 *
 * Every point is an array of `dim` KeyType coordinates. The tree stores only
 * pointers to those arrays (kdnode::value); it never copies or frees the
 * coordinates themselves, so the point data must outlive the tree.
 * The tree owns its kdnode objects and frees them in the destructor.
 */
class kdtree
{
public:
    // One node of the tree: a point plus the dimension it splits on.
    struct kdnode
    {
        kdnode *lnode, *rnode, *parent; // children and parent; NULL when absent
        KeyType* value;                 // borrowed pointer to the point's coordinates
        int splitdim;                   // dimension on which this node splits
        kdnode()
            : lnode(NULL), rnode(NULL), parent(NULL), value(NULL), splitdim(0)
        {
        }
    };

private:
    // Leaf capacity intended for a kdb-tree variant: a leaf would hold
    // [B/2, B) items, default 2. Not read by any code in this class.
    unsigned int B;
    int dim;      // number of coordinates per point
    kdnode* root; // NULL while the tree is empty

private:
    // Chooses the dimension to split on; a good choice keeps the tree shallow.
    int getsplitdim(std::vector<KeyType*>& input);
    // Splits the data set; `left` and `right` receive the two halves.
    void split_dataset(std::vector<KeyType*>& input, int const splitdim,
                       std::vector<KeyType*>& left, std::vector<KeyType*>& right);
    void create(kdnode*& node, std::vector<KeyType*>& input);
    // Declared only; no definition is visible alongside this class.
    void goback();

    // Euclidean distance between two points of `dim` coordinates each.
    double distance(KeyType* const aa, KeyType* const bb)
    {
        double sum = 0;
        for (int i = 0; i < dim; i++)
        {
            const double diff = double(aa[i] - bb[i]);
            sum += diff * diff; // plain multiply instead of pow(x, 2)
        }
        return std::sqrt(sum);
    }

    // True when elem1 precedes elem2 along dimension dth.
    bool UDless(int const dth, KeyType* elem1, KeyType* elem2)
    {
        return elem1[dth] < elem2[dth];
    }

    // The tree owns raw nodes, so a member-wise copy would delete them twice.
    // Copying is disabled: declared private and intentionally left undefined.
    kdtree(const kdtree&);
    kdtree& operator=(const kdtree&);

public:
    // Creates an empty tree for points with `dimen` coordinates (must be > 1).
    kdtree(int dimen = 2)
        : B(2), dim(dimen), root(NULL)
    {
        assert(dimen > 1);
    }

    KeyType* nearest(KeyType* const val);
    //void insert();
    void create(KeyType**& indata, int datanums);

    // Root node of the tree, or NULL when nothing has been built.
    kdnode* get_root() { return root; }

    // Frees every node; the point data the nodes refer to is left untouched.
    ~kdtree()
    {
        if (root == NULL)
            return;
        // Iterative traversal with an explicit stack so that a deep tree
        // cannot overflow the call stack.
        std::vector<kdnode*> pending;
        pending.push_back(root);
        while (!pending.empty())
        {
            kdnode* current = pending.back();
            pending.pop_back();
            if (current->lnode != NULL)
                pending.push_back(current->lnode);
            if (current->rnode != NULL)
                pending.push_back(current->rnode);
            delete current; // children are already queued
        }
        root = NULL;
    }
};
/**
 * Builds the tree from `datanums` points.
 *
 * @param indata   array of `datanums` pointers, each to `dim` coordinates.
 *                 The coordinates are not copied: nodes keep these row
 *                 pointers, so the rows must stay alive while the tree is used.
 * @param datanums number of points in `indata`.
 *
 * Any tree built by an earlier call is released first. When there is no input
 * the tree is left empty (get_root() returns NULL). As before, the input
 * points are echoed to standard output.
 */
void kdtree::create(KeyType**&indata, int datanums)
{
    // Release a previously built tree so that rebuilding does not leak nodes.
    if (root != NULL)
    {
        std::vector<kdnode*> pending(1, root);
        while (!pending.empty())
        {
            kdnode* current = pending.back();
            pending.pop_back();
            if (current->lnode != NULL)
                pending.push_back(current->lnode);
            if (current->rnode != NULL)
                pending.push_back(current->rnode);
            delete current;
        }
        root = NULL;
    }

    // No data: keep the tree empty rather than creating a root without a point.
    if (indata == NULL || datanums <= 0)
        return;

    // Echo the input points, one per line.
    for (int i = 0; i < datanums; i++)
    {
        for (int j = 0; j < dim; j++)
            std::cout << indata[i][j] << " ";
        std::cout << std::endl;
    }

    // Working copy of the row pointers; the recursive build reorders it.
    std::vector<KeyType*> input(indata, indata + datanums);
    root = new kdnode;
    create(root, input);
}
void kdtree::create(kdnode*&node, vector&input)
{
if (input.size() < 1)
return;
int splitinfo = getsplitdim(input);
node->value = input[input.size() / 2];
node->splitdim = splitinfo;
vectorleft, right;
//left,right为输出类型
split_dataset(input, splitinfo, left, right);
if (left.size() > 0)
{
kdnode*lnode = new kdnode;
lnode->parent = node;
node->lnode = lnode;
create(lnode, left);
}
if (right.size() > 0)
{
kdnode*rnode = new kdnode;
rnode->parent = node;
node->rnode = rnode;
create(rnode, right);
}
}
void kdtree::split_dataset(vector&input,
int const splitdim, vector&left, vector&right)
{
int nums = input.size();
left.assign(input.begin(), input.begin() + nums / 2);//将区间[first,last)的元素赋值到当前的vector容器中
input.erase(input.begin(), input.begin() + nums / 2 + 1);//将区间[first,last)的元素删除
right = input;
}
int kdtree::getsplitdim(vector&input)//根据方差决定在那一个维度分裂
{
double maxs = -1;
int splitdim;
int nums = input.size();
// 利用函数对象实现升降排序
struct CompNameEx{
CompNameEx(bool asce, int k) : asce_(asce), kk(k)
{}
bool operator()(KeyType*const& pl, KeyType*const& pr)
{
return asce_ ? pl[kk] < pr[kk] : pr[kk] < pl[kk]; // 《Eff STL》条款21: 永远让比较函数对相等的值返回false
}
private:
bool asce_;
int kk;
};
for (int i = 0; i < dim; i++)
{
double s = 0;
double mean = 0;
for (int j = 0; j < nums; j++)
mean += input[j][i];
mean /= double(nums);
for (int j = 0; j < nums; j++)
{
s += pow(double(input[j][i] - mean), double(2));
}
if (s > maxs)
{
splitdim = i;
maxs = s;
}
}
sort(input.begin(), input.end(), CompNameEx(true, splitdim));
return splitdim;
}
/**
 * Finds the stored point closest (Euclidean distance) to `val`.
 *
 * @param val query point with `dim` coordinates.
 * @return pointer to the coordinates of the nearest stored point, or NULL
 *         when the tree is empty or `val` is NULL.
 *
 * Depth-first search with an explicit stack. At each node the child on the
 * query's side of the splitting plane is searched first. The opposite child
 * is searched only if the plane is closer to the query than the best point
 * found so far; otherwise nothing beyond the plane can be closer.
 */
KeyType* kdtree::nearest(KeyType*const val)
{
    if (root == NULL || val == NULL)
        return NULL;

    KeyType* best = NULL; // closest point so far
    double bestdis = 0;   // its distance; meaningful once best != NULL

    // Each entry: a subtree still to visit, and a lower bound on the distance
    // from the query to any point inside that subtree.
    typedef std::pair<kdnode*, double> Entry;
    std::vector<Entry> pending;
    pending.push_back(Entry(root, 0.0));

    while (!pending.empty())
    {
        kdnode* const node = pending.back().first;
        const double bound = pending.back().second;
        pending.pop_back();

        // Prune: nothing in this subtree can beat the current best.
        if (best != NULL && bound >= bestdis)
            continue;

        const double dis = distance(val, node->value);
        if (best == NULL || dis < bestdis)
        {
            bestdis = dis;
            best = node->value;
        }

        // Signed offset of the query from the splitting plane.
        const int sd = node->splitdim;
        double gap = double(val[sd] - node->value[sd]);
        kdnode* nearside;
        kdnode* farside;
        if (gap > 0)
        {
            nearside = node->rnode;
            farside = node->lnode;
        }
        else
        {
            nearside = node->lnode;
            farside = node->rnode;
            gap = -gap;
        }

        // Push the far side first so the near side is popped (searched) first;
        // by then bestdis has usually shrunk enough to prune the far side.
        if (farside != NULL)
            pending.push_back(Entry(farside, gap > bound ? gap : bound));
        if (nearside != NULL)
            pending.push_back(Entry(nearside, bound));
    }
    return best;
}
// Demo entry point: builds a 2-D tree from six sample points and queries the
// nearest neighbour of (2, 4.5).
int _tmain(int argc, _TCHAR* argv[])
{
    kdtree kd(2);

    // Sample points, one {x, y} pair per row.
    // Earlier sample data kept from the original: { 12, 45, 34, 12, 17, 34, 43, 889, 86, 54 }
    KeyType bb[6][2] = { { 2, 3 }, { 5, 4 }, { 9, 6 }, { 4, 7 }, { 8, 1 }, { 7, 2 } };

    // Echo the sample points.
    for (int i = 0; i < 6; i++)
    {
        for (int j = 0; j < 2; j++)
            cout << bb[i][j] << " ";
        cout << endl;
    }

    // create() takes a reference to a KeyType** variable, so build a table of
    // row pointers into bb.
    KeyType** in = new KeyType*[6];
    for (int i = 0; i < 6; i++)
        in[i] = bb[i];

    kd.create(in, 6);

    KeyType hh[2] = { 2, 4.5 };
    KeyType*n = kd.nearest(hh);
    (void)n; // result is not printed; inspect it in a debugger

    // Allocated with new[], so it must be released with delete[].
    delete[] in;
    system("pause");
    return 0;
}
在 Python 中使用 kd-tree
可以直接使用 scipy.spatial.KDTree，例如：
>>> import numpy as np
>>> from scipy import spatial
>>> x, y = np.mgrid[0:5, 2:8]
>>> tree = spatial.KDTree(list(zip(x.ravel(), y.ravel())))
>>> tree.data
array([[0, 2],
[0, 3],
[0, 4],
[0, 5],
[0, 6],
[0, 7],
[1, 2],
[1, 3],
[1, 4],
[1, 5],
[1, 6],
[1, 7],
[2, 2],
[2, 3],
[2, 4],
[2, 5],
[2, 6],
[2, 7],
[3, 2],
[3, 3],
[3, 4],
[3, 5],
[3, 6],
[3, 7],
[4, 2],
[4, 3],
[4, 4],
[4, 5],
[4, 6],
[4, 7]])
>>> pts = np.array([[0, 0], [2.1, 2.9]])
>>> tree.query(pts)
(array([ 2. , 0.14142136]), array([ 0, 13]))
详见源码