简易KD树(C++)

k-d树介绍和类定义

k-d树(k-维树的缩写)是在k维欧几里得空间中组织点的数据结构,可用于空间数据库、游戏优化等领域。具体来说,k-d树是每个节点都为k维点的二叉树。每个非叶子节点都可以看作用一个超平面把空间分割成两个半空间。

下面是我按照上述描述粗略实现的一棵简易KD树,完成了树的创建、插入、按区域搜索、最近邻搜索等操作。本例中多维点由 vector<double> 类型表示,其 size() 即为点的维数。KD树的类定义如下。

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>
using namespace std;

#ifndef KDTREE_H
#define KDTREE_H


// A simple k-d tree whose points are std::vector<double> with k coordinates.
// The tree owns its nodes (allocated with new) and frees them in the destructor,
// so it is intentionally non-copyable.
class KDTree
{
	// One node of the tree; every node stores exactly one k-dimensional point.
	struct KDNode
	{
		bool m_isLeaf = false;          // true when the node has no children
		std::vector<double> m_point;    // the k-dimensional point held by this node
		int m_split = 1;                // 1-based index of the dimension this node splits on
		KDNode* m_parentNode = nullptr; // nullptr for the root
		KDNode* m_leftNode = nullptr;   // subtree on the lower side of the split
		KDNode* m_rightNode = nullptr;  // subtree on the upper side of the split
	};
private:
	using st = std::vector<double>::size_type;

	KDNode* m_root = nullptr;                  // root node; nullptr while the tree is empty
	int m_k = 0;                               // dimensionality k of every point
	int m_pointNum = 0;                        // number of points held
	std::vector<std::vector<double>> m_points; // copy of the points held
public:
	// Builds a k-d tree of dimension k from allpoints. An empty point set yields an
	// empty tree (m_root == nullptr) instead of running the builder on no data.
	KDTree(int k, std::vector<std::vector<double>> allpoints)
		: m_k(k),
		  m_pointNum(static_cast<int>(allpoints.size())),
		  m_points(std::move(allpoints))
	{
		if (m_points.empty())
			return;
		m_root = new KDNode();
		KDTreeBuild(m_points, m_root);
	}

	// Frees every node. Iterative so that deep (unbalanced) trees cannot overflow the stack.
	~KDTree()
	{
		std::vector<KDNode*> pending;
		if (m_root != nullptr)
			pending.push_back(m_root);
		while (!pending.empty())
		{
			KDNode* node = pending.back();
			pending.pop_back();
			if (node->m_leftNode != nullptr)
				pending.push_back(node->m_leftNode);
			if (node->m_rightNode != nullptr)
				pending.push_back(node->m_rightNode);
			delete node;
		}
	}

	// The tree owns raw node pointers; a shallow copy would lead to a double free.
	KDTree(const KDTree&) = delete;
	KDTree& operator=(const KDTree&) = delete;

	// Inserts one point; it must have exactly k coordinates.
	void Insert(std::vector<double> newpoint);

	// Returns all stored points p with from[i] <= p[i] <= to[i] in every dimension i.
	std::vector<std::vector<double>> SearchByRegion(std::vector<double> from, std::vector<double> to) const;

	// Returns the stored point closest (Euclidean distance) to goalpoint.
	std::vector<double> SearchNearestNeighbor(std::vector<double> goalpoint);
private:
	// Recursively builds the subtree rooted at root from points.
	void KDTreeBuild(std::vector<std::vector<double>> points, KDNode* root);
	// Searches the subtree treeroot for a point closer to goalpoint than curdis,
	// updating curdis and nearestpoint when one is found.
	void SearchNearestByTree(std::vector<double> goalpoint, double& curdis, const KDNode* treeroot, std::vector<double>& nearestpoint);
	// Recursively collects the points of the subtree temp that lie inside [from, to].
	void SearchRecu(std::vector<double> from, std::vector<double> to, const KDNode* temp, std::vector<std::vector<double>>& nodes) const;
	// Euclidean distance between two points.
	double CalDistance(std::vector<double> point1, std::vector<double> point2);
};

#endif // KDTREE_H

kd树的构建

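树的构建思路如下(原图缺失,以文字说明):对当前点集,先计算每一维坐标的方差,选取方差最大的维作为分割维;再求出该维坐标的中位数,以该维坐标等于中位数的点作为当前节点,坐标小于中位数的点划入左子树,大于中位数的点划入右子树;然后对左右两部分分别递归建树,直到某一部分为空。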

树的构建代码:

void KDTree::KDTreeBuild(vector>points, KDNode* root)
{
	int indexpart = 0, max = 0;
	vectortemp;
	for (st i = 0; i < m_k; i++)
	{
		temp.clear();
		for each (auto var in points)temp.push_back(var[i]);

		//计算平均值
		double ave = accumulate(temp.begin(), temp.end(), 0.0) / m_pointNum;

		//计算方差
		double accum = 0.0;
		for each(auto var in temp) 	accum += (var - ave)*(var - ave);

		if (accum > max)
		{
			max = accum;
			indexpart = i;
		}
	}
	//此时indexpart的值为当前要进行分裂的维数
	temp.clear();
	for each (auto var in points)temp.push_back(var[indexpart]);
	//找到中值
	sort(temp.begin(), temp.end());
	double median = temp[(temp.size()) >> 1];
	//将点分为左右两部分
	vector>leftpoints, rightpoints;
	for each (auto var in points)
	{
		if (var[indexpart] < median)
			leftpoints.push_back(var);

		if (var[indexpart] == median)
		{
			root->m_split = indexpart + 1;
			root->m_point = var;
		}
		if (var[indexpart] > median)
			rightpoints.push_back(var);
	}
	//递归建树
	if (leftpoints.size() == 0 && rightpoints.size() == 0)root->m_isLeaf = true;
	if (leftpoints.size() != 0)
	{
		root->m_leftNode = new KDNode();
		root->m_leftNode->m_parentNode = root;
		KDTreeBuild(leftpoints, root->m_leftNode);
	}
	if (rightpoints.size() != 0)
	{
		root->m_rightNode = new KDNode();
		root->m_rightNode->m_parentNode = root;
		KDTreeBuild(rightpoints, root->m_rightNode);
	}
}

节点的插入

节点的插入与二叉搜索树类似(二叉搜索树可以看作一维kd树):从根节点开始,在每一层比较新点与当前节点在该节点分割维上的坐标值,若新点的坐标更大则进入右子节点,否则(小于或等于)进入左子节点;当遇到空指针时,就找到了新节点要插入的位置。代码如下:

void KDTree::Insert(vectornewpoint)
{
	if (newpoint.size() != m_k)
	{
		cerr << "插入点维数与KD树不匹配" << endl;
		exit(1);
	}
	KDNode*temp = m_root;
	if (temp == nullptr)//若树为空树
	{
		temp = new KDNode();
		temp->m_isLeaf = true;
		temp->m_split = 1;
		temp->m_point = newpoint;
		return;
	}
	if (temp->m_isLeaf)//若树只有一个节点,做好被插入的准备
	{
		temp->m_isLeaf = false;
		int max = 0, partindex = 0;
		for (st i = 0; i < m_k; i++)
		{
			double delta = abs(newpoint[i] - temp->m_point[i]);
			if (delta > max)
			{
				max = delta;
				temp->m_split = i + 1;
			}
		}
	}
	while (true)
	{
		int partindex = temp->m_split - 1;
		KDNode*nextnode;
		if (newpoint[partindex] > temp->m_point[partindex])
		{
			if (temp->m_rightNode == nullptr)//插入
			{
				temp->m_rightNode = new KDNode();
				temp->m_rightNode->m_parentNode = temp;
				temp->m_rightNode->m_isLeaf = true;
				temp->m_rightNode->m_split = 1;
				temp->m_rightNode->m_point = newpoint;
				break;
			}
			else nextnode = temp->m_rightNode;
		}
		else
		{
			if (temp->m_leftNode == nullptr)//插入
			{
				temp->m_leftNode = new KDNode();
				temp->m_leftNode->m_parentNode = temp;
				temp->m_leftNode->m_isLeaf = true;
				temp->m_leftNode->m_split = 1;
				temp->m_leftNode->m_point = newpoint;
				break;
			}
			else nextnode = temp->m_leftNode;
		}

		if (nextnode->m_isLeaf)//如果下一个点是叶子节点,做好被插入的准备
		{
			nextnode->m_isLeaf = false;
			int max = 0, partindex = 0;
			for (st i = 0; i < m_k; i++)
			{
				double delta = abs(newpoint[i] - nextnode->m_point[i]);
				if (delta > max)
				{
					max = delta;
					nextnode->m_split = i + 1;
				}
			}
		}
		temp = nextnode;//往下走
	}
}

根据区域进行查询

所谓根据区域进行查询,即输入一个区域(其维数需与KD树一致),得到KD树中位于该区域内的所有坐标点。这种搜索往往需要沿多条路径向下进行,并记录沿途所有符合条件的点。在本系统中,用户通过输入两个点来指定查询区域:起点的每一维坐标都须小于或等于终点的对应坐标,这样两个点就能确定一个(各边与坐标轴平行的)查询区域,区域边界上的点也计入结果。该模块的详细代码如下。

vector>KDTree::SearchByRegion(vectorfrom, vectorto)const
{
	vector>result;
	if (from.size() != m_k || to.size() != m_k)
	{
		cerr << "搜索区域维数与KD树不匹配" << endl;
		exit(1);
	}
	for (st i = 0; i < m_k; i++)
	{
		if (from[i] > to[i])
		{
			cerr << "请保证区域起始点的所有坐标值小于区域终点" << endl;
			exit(1);
		}
	}
	SearchRecu(from, to, m_root, result);
	return result;
}

void KDTree::SearchRecu(vectorfrom, vectorto, const KDNode*temp, vector>&nodes)const
{
	if (temp == nullptr)return;

	int partindex = temp->m_split - 1;
	int value = temp->m_point[partindex];
	if (from[partindex] <= value && to[partindex] >= value)//当前点在范围内
	{
		bool inregion = true;
		for (st i = 0; i < m_k; i++)
		{
			if (from[i] > temp->m_point[i] || to[i] < temp->m_point[i])
				inregion = false;
		}
		if (inregion)nodes.push_back(temp->m_point);
		SearchRecu(from, to, temp->m_leftNode, nodes);
		SearchRecu(from, to, temp->m_rightNode, nodes);
	}
	else if (value > to[partindex])
		SearchRecu(from, to, temp->m_leftNode, nodes);

	else if (value < from[partindex])
		SearchRecu(from, to, temp->m_rightNode, nodes);
}

最近邻搜索

kd树的最近邻搜索算法应用广泛,其思路如下:

  1. 从根节点开始,递归地往下移。往左还是往右的决定方法与插入元素的方法一样(如果输入点在分割面的左边则进入左子节点,在右边则进入右子节点)。
  2. 一旦移动到叶节点,将该节点当作"当前最佳点"。
  3. 回溯(解开递归),并对每个经过的节点运行下列步骤:
    1. 如果当前所在点比当前最佳点更靠近输入点,则将其变为当前最佳点。
    2. 检查另一边子树中是否可能有更近的点:若当前最近距离大于输入点到该节点分割超平面的距离(即以输入点为球心、当前最近距离为半径的超球体与分割超平面相交),则另一边可能有更近的点,需要进入该子树搜索;否则可以跳过整棵子树。
  4. 当回溯到根节点并处理完毕后,最近邻搜索完成。

本例中详细代码如下:

// Euclidean distance between point1 and point2.
// Terminates the process with exit(1) if the two points have different dimensionality.
double KDTree::CalDistance(vector<double> point1, vector<double> point2)
{
	if (point1.size() != point2.size())
	{
		// endl added for consistency with every other error message in this class.
		cerr << "两个点的维度不相同" << endl;
		exit(1);
	}
	double distance = 0.0;
	for (st i = 0; i < point1.size(); i++)
	{
		const double diff = point1[i] - point2[i];
		distance += diff * diff; // plain multiply instead of the general-purpose pow(diff, 2)
	}

	return sqrt(distance);
}

vector KDTree::SearchNearestNeighbor(vector goalpoint)
{
	vectornearestpoint;
	KDNode*temp = m_root;
	//找到最邻近的叶子节点
	while (!temp->m_isLeaf)
	{
		int partindex = temp->m_split - 1;
		if (temp->m_leftNode != nullptr && goalpoint[partindex] < temp->m_point[partindex])
		{
			temp = temp->m_leftNode;
		}
		else if (temp->m_rightNode != nullptr)
		{
			temp = temp->m_rightNode;
		}
	}
	nearestpoint = temp->m_point;
	double curdis = CalDistance(goalpoint, nearestpoint);
	//向上回溯
	bool isleft = false;
	while (temp != m_root)
	{
		isleft = (temp == temp->m_parentNode->m_leftNode);//判断当前点是否其父节点的左子节点

		temp = temp->m_parentNode;//指针向上跟踪
		if (CalDistance(goalpoint, temp->m_point) < curdis)
		{
			nearestpoint = temp->m_point;
			curdis = CalDistance(goalpoint, nearestpoint);
		}
		int partindex = temp->m_split - 1;
		//若圆与另一区域有相交,即另一边子树可能有更近的点
		if (curdis > abs(temp->m_point[partindex] - goalpoint[partindex]))
		{
			if (isleft)
			{
				SearchNearestByTree(goalpoint, curdis, temp->m_rightNode, nearestpoint);
			}
			else SearchNearestByTree(goalpoint, curdis, temp->m_leftNode, nearestpoint);
		}
	}
	return nearestpoint;
}

void KDTree::SearchNearestByTree(vector goalpoint, double&curdis, const KDNode*treeroot, vector&nearestpoint)
{
	if (treeroot == nullptr)return;
	double newdis = CalDistance(goalpoint, treeroot->m_point);
	if (newdis < curdis)
	{
		curdis = newdis;
		nearestpoint = treeroot->m_point;
	}
	SearchNearestByTree(goalpoint, curdis, treeroot->m_leftNode, nearestpoint);
	SearchNearestByTree(goalpoint, curdis, treeroot->m_rightNode, nearestpoint);
}

本人在这里只是粗略地实现了kd树的几个功能,应该还有很多细节可以完善。谢谢观看:)

 

你可能感兴趣的:(C++)