kd-tree的实现

参考百度百科：http://baike.baidu.com/link?url=JLBeRUhL6WLyp8R6TAFDD8swLfazjQnOaSXBY3AydkrVQG8XpCJ8EIh4bWpB02wQxxzPrK723ulRCzSKxkFLy_

下面是我的 C++ 实现：

 

// kd-tree.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

using namespace std;
#define KeyType double

class kdtree
{
public:
	struct kdnode
	{
		kdnode*lnode, *rnode, *parent;
		double*value;
		int splitdim;//该节点在哪个维度分裂
		kdnode()
		{
			lnode = rnode = parent = NULL;
		}
	};
private:
	unsigned int B;//用于构建kdb树时指定叶子中包含的数据个数,默认为2,既包含[B/2,B)个数据
	int dim;//维数
	kdnode*root;
private:
	//选择在哪个维度分裂,合理的选择分裂可以减小树的高度
	int getsplitdim(vector&input);
	//分裂数据集,left,right为分裂结果
	void split_dataset(vector&input, int const splitdim, vector&left, vector&right);
	void create(kdnode*&node, vector&input);
	void goback();
	double distance(KeyType*const aa, KeyType*const bb)
	{
		double dis = 0;
		for (int i = 0; i < dim; i++)
			dis += pow(double(aa[i] - bb[i]), double(2));
		return sqrt(dis);
	}
	bool UDless(int const dth, KeyType* elem1, KeyType*elem2)
	{
		return elem1[dth] < elem2[dth];
	}
public:
	kdtree(int dimen = 2)
	{
		root = NULL;
		_ASSERTE(dimen > 1);
		dim = dimen;
	}
	KeyType* nearest(KeyType*const val);
	//void insert();

	void create(KeyType**&indata, int datanums);
	kdnode*get_root(){ return root; }
	~kdtree()
	{
		if (root == NULL)
			return;
		vectoraa, bb;
		aa.push_back(root);
		while (!aa.empty())
		{
			kdnode*cc = aa.back();
			bb.push_back(cc);
			aa.pop_back();
			if (cc->lnode != NULL)
				aa.push_back(cc->lnode);
			if (cc->rnode != NULL)
				aa.push_back(cc->rnode);
		}
		for (int i = 0; i < bb.size(); i++)
			delete bb[i];

	};
};

// Build (or rebuild) the tree from `datanums` points.
//   indata   - array of `datanums` pointers, each to `dim` KeyType coordinates.
//              The pointers are stored in the tree, not copied, so the point
//              data must stay alive as long as the tree uses it.
//   datanums - number of points. If it is <= 0, or indata is NULL, the tree is
//              left empty and nearest() returns NULL.
// Any tree built by an earlier call is freed first, so calling this repeatedly
// does not leak.
void kdtree::create(KeyType**&indata, int datanums)
{
	// Release a previously built tree (iteratively, like the destructor).
	std::vector<kdnode*> stale;
	if (root != NULL)
		stale.push_back(root);
	while (!stale.empty())
	{
		kdnode* node = stale.back();
		stale.pop_back();
		if (node->lnode != NULL)
			stale.push_back(node->lnode);
		if (node->rnode != NULL)
			stale.push_back(node->rnode);
		delete node;
	}
	root = NULL;

	// An empty data set must not produce a root with an unset `value`:
	// nearest() would dereference it.
	if (indata == NULL || datanums <= 0)
		return;

	std::vector<KeyType*> input(indata, indata + datanums);
	root = new kdnode;
	create(root, input);
}

void kdtree::create(kdnode*&node, vector&input)
{
	if (input.size() < 1)
		return;
	
	int splitinfo = getsplitdim(input);
	node->value = input[input.size() / 2];
	node->splitdim = splitinfo;
	vectorleft, right;
	//left,right为输出类型
	split_dataset(input, splitinfo, left, right);
	if (left.size() > 0)
	{
		kdnode*lnode = new kdnode;
		lnode->parent = node;
		node->lnode = lnode;
		create(lnode, left);
	}
	if (right.size() > 0)
	{
		kdnode*rnode = new kdnode;
		rnode->parent = node;
		node->rnode = rnode;
		create(rnode, right);
	}

}


void kdtree::split_dataset(vector&input,
	int const splitdim, vector&left, vector&right)
{
	int nums = input.size();
	left.assign(input.begin(), input.begin() + nums / 2);//将区间[first,last)的元素赋值到当前的vector容器中
	input.erase(input.begin(), input.begin() + nums / 2 + 1);//将区间[first,last)的元素删除
	right = input;
}


int kdtree::getsplitdim(vector&input)//根据方差决定在那一个维度分裂
{
	double maxs = -1;
	int splitdim;
	int nums = input.size();
	// 利用函数对象实现升降排序  
	struct CompNameEx{
		CompNameEx(bool asce, int k) : asce_(asce), kk(k)
		{}
		bool operator()(KeyType*const& pl, KeyType*const& pr)
		{
			return asce_ ? pl[kk] < pr[kk] : pr[kk] < pl[kk]; // 《Eff STL》条款21: 永远让比较函数对相等的值返回false  
		}
	private:
		bool asce_;
		int kk;
	};
	for (int i = 0; i < dim; i++)
	{
		double s = 0;
		double mean = 0;
		for (int j = 0; j < nums; j++)
			mean += input[j][i];
		mean /= double(nums);
		for (int j = 0; j < nums; j++)
		{
			s += pow(double(input[j][i] - mean), double(2));
		}
		if (s > maxs)
		{
			splitdim = i;
			maxs = s;
		}
	}
	sort(input.begin(), input.end(), CompNameEx(true, splitdim));
	return splitdim;
}

// Exact nearest-neighbour query.
//   val - query point; must point at `dim` coordinates.
// Returns the stored point with the smallest Euclidean distance to `val`. The
// result is a pointer into the data passed to create(), not a copy. Returns NULL
// if the tree is empty or `val` is NULL. On ties the first point found is kept.
//
// This is standard kd-tree branch-and-bound. The search descends toward the
// query and, at each level, remembers the sibling ("far") subtree together with
// a lower bound on the distance from `val` to any point in it: the distance to
// that node's splitting plane. A far subtree is explored, all the way down,
// only if its bound is below the best distance found so far.
KeyType* kdtree::nearest(KeyType*const val)
{
	if (root == NULL || val == NULL)
		return NULL;

	double bestdis = std::numeric_limits<double>::infinity();
	KeyType* best = NULL;

	// Pending subtrees: (subtree root, lower bound on distance from val into it).
	std::vector<std::pair<kdnode*, double>> pending;
	pending.push_back(std::make_pair(root, 0.0));
	while (!pending.empty())
	{
		kdnode* node = pending.back().first;
		const double bound = pending.back().second;
		pending.pop_back();
		if (bound >= bestdis)
			continue;  // nothing in this subtree can be strictly closer

		while (node != NULL)
		{
			const double dis = distance(val, node->value);
			if (dis < bestdis)
			{
				bestdis = dis;
				best = node->value;
			}

			// Go to the query's side of the splitting plane (left when exactly on
			// it). Every point on the other side is at least |diff| away, and
			// also at least `bound` away, since it lies in the current subtree.
			const int sd = node->splitdim;
			const double diff = val[sd] - node->value[sd];
			kdnode* nearchild = diff > 0 ? node->rnode : node->lnode;
			kdnode* farchild = diff > 0 ? node->lnode : node->rnode;
			if (farchild != NULL)
				pending.push_back(std::make_pair(farchild, std::max(bound, std::fabs(diff))));
			node = nearchild;
		}
	}
	return best;
}




int _tmain(int argc, _TCHAR* argv[])
{
	kdtree kd(2);
	KeyType bb[6][2] = { 2, 3, 5, 4, 9, 6, 4, 7, 8, 1, 7, 2 };// { 12, 45, 34, 12, 17, 34, 43, 889, 86, 54 };
	KeyType** in = new KeyType*[6];
	for (int i = 0; i < 6; i++)
	{
		for (int j = 0; j < 2; j++)
			cout << bb[i][j] << "   ";
		cout << endl;
	}
	for (int i = 0; i < 6; i++)
		in[i] = bb[i];
	kdtree::kdnode*root = kd.get_root();
	kd.create(in, 6);
	root = kd.get_root();
	KeyType hh[2] = { 2, 4.5 };
	KeyType*n = kd.nearest(hh);

	
	delete in;
	system("pause");
	return 0;
}

 

 

 

 

 

 

在 Python 中使用 kd-tree

scipy.spatial.KDTree

 

 

>>> import numpy as np
>>> from scipy import spatial
>>> x, y = np.mgrid[0:5, 2:8]
>>> tree = spatial.KDTree(list(zip(x.ravel(), y.ravel())))
>>> tree.data
array([[0, 2],
       [0, 3],
       [0, 4],
       [0, 5],
       [0, 6],
       [0, 7],
       [1, 2],
       [1, 3],
       [1, 4],
       [1, 5],
       [1, 6],
       [1, 7],
       [2, 2],
       [2, 3],
       [2, 4],
       [2, 5],
       [2, 6],
       [2, 7],
       [3, 2],
       [3, 3],
       [3, 4],
       [3, 5],
       [3, 6],
       [3, 7],
       [4, 2],
       [4, 3],
       [4, 4],
       [4, 5],
       [4, 6],
       [4, 7]])
>>> pts = np.array([[0, 0], [2.1, 2.9]])
>>> tree.query(pts)
(array([ 2.        ,  0.14142136]), array([ 0, 13]))

详见 scipy.spatial.KDTree 的源码。

 

 

 

 

你可能感兴趣的:(数据结构与算法)