KNN算法的KD树C++实现

KD树本质上是一种二叉树,它拥有具备特定排列顺序的分裂节点以便查找数据,即在二叉排序树之中,某个分裂节点左子树的值均小于分裂节点的值,而右侧均大于分裂节点的值,如果用中序遍历这棵树,它的打印顺序将是从小到大递增的顺序。当然剩下的科普就不说了,这也是在PCL库当中,最常用的轮子之一,处理点云速度非常快。另外,KNN算法是机器学习训练的一种非常有效的分类方法,手撸这个算法就显得很重要了。当然,在企业级应用中,还是用别人的轮子吧~

首先是数据与树节点的结构体,在树节点中,具备数据、分裂维与左右孩子:

struct Data
{
  int i;
  float x;
  float y;
};

struct Tnode
{
  struct Data data;
  int split;
  struct Tnode *left;
  struct Tnode *right;
};

 接下来是判断每次分裂时的分裂维度,在这里只选择二维平面上的两个方向进行比较:

bool compareX(struct Data a, struct Data b)
{
  return a.x < b.x;
}

bool compareY(struct Data a, struct Data b)
{
  return a.y < b.y;
}

bool equal(struct Data a, struct Data b)
{
  if(a.x == b.x&&a.y == b.y)
  {
    return true;
  }
  else
  {
    return false;
  }
}

//在每次决定树的分裂节点时,为了枝的分散效果需比较XY维度的方差,取大者为分裂依据
void chooseSplit(std::vector pointsData, int size, int &split, struct Data &splitchoice)
{
  float temp1, temp2;
  temp1 = temp2 = 0;

  //计算X方向上的方差
  for(int i = 0;i < size;++i)
  {
    temp1 += float(1.0/float(size)) * pointsData[i].x * pointsData[i].x;
    temp2 += float(1.0/float(size)) * pointsData[i].x;
  }
  float v1 = temp1 - temp2 * temp2;
  temp1 = temp2 = 0;

  //计算Y方向上的方差
  for (int j = 0;j < size;++j)
  {
    temp1 += float(1.0/float(size)) * pointsData[i].y * pointsData[i].y;
    temp2 += float(1.0/float(size)) * pointsData[i].y;
  }
  float v2 = temp1 - temp2 * temp2;

  //取大者为分裂维
  split = v1 > v2 ? 0:1;
  struct Data tempx, tempy;
  if(0 == split)
  {
    sort(pointsData.begin(), pointsData.end(), compareX);
  }
  else
  {
    sort(pointsData.begin(), pointsData.end(), compareY);
  }

  //给分裂节点赋值
  splitchoice.i = pointsData[size/2].i;
  splitchoice.x = pointsData[size/2].x;
  splitchoice.y = pointsData[size/2].y;
}

接下来是递归建立KD树,从根节点开始依次划分二维空间区域直到左右孩子均为空指针。

Tnode* buildKdtree(std::vector pointsData, int size, Tnode* T)
{
  //递归调用开始,只有当子树高度为0时结束调用
  if(size == 0)
  {
    return NULL;
  }
  else
  {
    int split;
    struct Data data;
    chooseSplit(pointsData, size, split, data);
    std::vector pointsDataRight;
    std::vector pointsDataLeft;
    pointsDataRight.clear();
    pointsDataLeft.clear();
    int sizeLeft, sizeRight;
    sizeLeft = sizeRight = 0;

    if(0 == split)
    {
      for(int i = 0;i < size;++i)
      {
        if(!equal(pointsData[i], data) && pointsData[i].x <= data.x)
        {
          pointsDataLeft.push_back(pointsData[i]);
          //cout<<"pointsDataLeft["<float distance(struct Data a, struct Data b)
{
   float dist = (a.x - b.x)*(a.x - b.x)+(a.y - b.y)*(a.y - b.y);
   return sqrt(dist);
}

首先第一步,直接查找直到叶节点,这一步将查询点从根节点放入,依次根据分裂维比较,直至查找到叶节点,那么我们得到一个疑似的“最近点”。

void findNearest(std::vector pointsData, struct Data query)
{
  Tnode* nearest;
  //设置一个栈作为寻找最近点的路径
	stack searchPath;

  //将根节点加入到栈中
  T->visited = true;
  searchPath.push(T);

  //向下搜索直到叶节点
	while(T->left != NULL && T->right != NULL)
	{
	  if(0 == T->split)
	  {
      if(query.x <= T->data.x)
        T = T->left;
      else if(query.x > T->data.x)
        T = T->right;
    }
    else if(1 == T->split)
    {
      if(query.y <= T->data.y)
        T = T->left;
      else if(query.y > T->data.y)
        T = T->right;
    }
    T->visited = true;
	  searchPath.push(T);
	}

  //将此时的叶节点设为最近点
  nearest = T;

  //当前节点的指针,以及当前节点的父节点指针
  Tnode* current, *current_parent;
  //临时指针
  Tnode* temp;
  //与当前节点的距离,以及与其兄弟子空间超平面的距离
  double dist, dist_bro;
  //如果某节点的兄弟子空间已被判断过就略去并继续回溯
  bool bro_visited = false;
  //向上回溯寻找是否具备更好的点
  while(!searchPath.empty())
  {
    current = searchPath.top();
    searchPath.pop();
    dist = distance(query, current->data);
    
    if(dist < distance(query, nearest->data))
    {
      nearest = current;
    }

    if(!searchPath.empty())
    {
      current_parent = searchPath.top();
      if(current = current_parent->left)
      {
        if(current_parent->right->visited) bro_visited = true;
        else bro_visited = false;
      }
      else if(current = current_parent->right)
      {
        if(current_parent->left->visited) bro_visited = true;
        else bro_visited = false;
      }

      if(!bro_visited)
      {
        //计算兄弟子空间超平面的距离,首先需要判断超平面是平行于x轴还是y轴
        if(0 == current_parent->split)
        {
          dist_bro = fabs(current_parent->data.x - query.x);
          temp = current_parent->right;
        }
        else if(1 == current_parent->split)
        {
          dist_bro = fabs(current_parent->data.y - query.y);
          temp = current_parent->left;
        }
        //如果以查询点为圆心,与当前节点距离为半径画的圆侵犯了兄弟节点的子空间,那么跳到隔壁子空间去查找
        if(dist > dist_bro)
        {
          //将兄弟节点纳入栈中
          temp->visited = true;
          searchPath.push(temp);
          //同样也是先直接搜索到叶节点为止,再回溯
          while(temp->left != NULL && temp->right != NULL)
          {
        	  if(0 == temp->split)
        	  {
              if(query.x <= temp->data.x)
                temp = temp->left;
              else if(query.x > temp->data.x)
                temp = temp->right;
            }
            else if(1 == temp->split)
            {
              if(query.y <= temp->data.y)
                temp = temp->left;
              else if(query.y > temp->data.y)
                temp = temp->right;
            }
            temp->visited = true;
            searchPath.push(temp);
          }

          //既然已经侵犯了兄弟子空间,那判断这次搜索到的叶节点是否比原来的叶节点更优,若是则将现在搜索到的叶节点设为最近点
          if(dist > distance(temp->data, query))
            nearest = temp;
        }
        //如果并未侵犯兄弟的子空间,那么比较父节点与当前节点谁更优
        else
        {
          if(dist > distance(current_parent->data, query))
            nearest = current_parent;
        }
      }
      else
      {
        if(dist > distance(current_parent->data, query))
          nearest = current_parent;
      }
    }

  }

}

int main()
{
	struct Data data;
	std::vector pointsData;
	int i = 1;
	double x, y;
	cout<<"请输入x,y坐标"<>x>>y)
	{
		cout<<"请输入x,y坐标"<

如果是K近邻的话,目前比较好的方法是利用优先级队列来维护K个最近点,当查找到比队列中更好的点时则将末尾的节点拿出,这就是大顶堆的算法。

先占坑,明天再写。

 

你可能感兴趣的:(数据结构与算法)