KdTree 最近邻查找算法(C++描述)

根据统计学习方法写的KdTree实现,###

参考了这个博客的主要思路,但是在关于如何搜索最近邻上有些不同。
1.我采取在发现可能的路径后,采取扩展路径到叶子节点,生成一个新路径后重新计算最近路径。而这个博客中只检查了路径上与超球体相交的点。没有递归搜索
2.他的博客用利用方差确定分割的方向。我则选用了简单的依次更换策略。
#include
#include
#include
#include
using namespace std;
struct Node
{
double x;
double y;
};
struct KdTree
{
Node val;
int split; /描述根据X或Y进行划分/
KdTree* left;
KdTree* right;
};
KdTree myKdTree{};
const int N = 6;
const int dim = 2;
Node dataSet[N] = {
{ 2,3 },
{ 5,4 },
{ 9,6 },
{ 4,7 },
{ 8,1 },
{ 7,2 }
};
int time = 0;/记录寻找分割次数/
stack> search_path;/记录搜索过程的路经*/

/*结果结构*/
struct result {
    Node resNode;
    double dist;
};

/*X,Y维比较函数*/
bool compareX(Node a,Node b) {
    return a.x > b.x;
}
bool compareY(Node a, Node b) {
    return a.y > b.y;
}
void chooseSplit(Node unsortSet[],Node& splitData,int size) {
    if (time % 2 == 0) {
        /*根据x维分割*/
        sort(unsortSet, unsortSet + size, compareX);
    }
    else {
        /*根据y维分割*/
        sort(unsortSet, unsortSet + size, compareY);
    }
    int mid;
    if (size % 2 == 0) {
        mid = size / 2 - 1;
    }
    else {
        mid = size / 2;
    }
    splitData.x = unsortSet[mid].x;
    splitData.y = unsortSet[mid].y;
    time++;
}

/*构造kdTree*/
KdTree* build(int size,Node unsortSet[], KdTree* tree) {
    if (size == 0) {
        return 0;
    }
    else {
        int split;
        Node splitData;
        chooseSplit(unsortSet,splitData, size);
        Node leftset[100]{};
        Node rightset[100]{}; 
        int leftnum = 0;
        int rightnum = 0;
        if (time % 2 == 1) {
            /*根据x维分割,time加一后*/
            split = 0;
            for (int i = 0; i < size; i++) {
                if (splitData.x > unsortSet[i].x) {
                    leftset[leftnum] = unsortSet[i];
                    leftnum++;
                }
                else if(splitData.x < unsortSet[i].x) {
                    rightset[rightnum] = unsortSet[i];
                    rightnum++;
                }
            }
        }
        else {
            split = 1;
            for (int i = 0; i < size; i++) {
                if (splitData.y > unsortSet[i].y) {
                    leftset[leftnum] = unsortSet[i];
                    leftnum++;
                }
                else if (splitData.y < unsortSet[i].y) {
                    rightset[rightnum] = unsortSet[i];
                    rightnum++;
                }
            }
        }
        tree = new KdTree;
        tree->val = splitData;
        tree->split = split;
        tree->left = build(leftnum, leftset, tree->left);
        tree->right = build(rightnum, rightset, tree->right);
        return tree;

    }
}

/*计算距离 p=2*/
double distance(Node a, Node b) {
    return (a.x - b.x)*(a.x - b.x) + (a.y - b.y)*(a.y - b.y);
}

/*建立搜索路径*/
void buildpath(Node target, KdTree* tree) {
    KdTree* pSearch = tree;
    while (pSearch != NULL) {
        search_path.push(pSearch);
        if (pSearch->split == 0) {
            if (target.x < pSearch->val.x) {
                pSearch = pSearch->left;
            }
            else {
                pSearch = pSearch->right;
            }
        }
        else {
            if (target.y < tree->val.y) {
                pSearch = pSearch->left;
            }
            else {
                pSearch = pSearch->right;
            }
        }
    }
}

/*根据搜索路径查找最近邻*/
result findnearest (Node target,KdTree* tree){
    /*初始化搜索路径*/
    buildpath(target, tree);
    Node nearest = search_path.top()->val;
    double dist = distance(nearest, target);
    search_path.pop();
    //搜索潜在的路径上最近点。
    KdTree* pBack;
    while (search_path.size() != 0) {
        pBack = search_path.top();
        search_path.pop();
        if (pBack->left == NULL && pBack->right == NULL) {
            if (distance(pBack->val, target) < dist) {
                dist = distance(pBack->val, target);
                nearest = pBack->val;
            }
        }
        else {
            if (pBack->split == 0) {
                if (abs(target.x - pBack->val.x) < dist) {//X方向相交。
                    KdTree* newTree{};
                    if ((target.x > pBack->val.x)&&(pBack->left !=NULL)) {//点在右侧,向左搜索。
                        search_path.push(pBack->left);
                        newTree = pBack->left;
                    }
                    if ((target.x < pBack->val.x) && (pBack->right != NULL)) {
                        search_path.push(pBack->right);
                        newTree = pBack->right;
                    };
                    //搜索新发现的路径
                    buildpath(target, newTree);
                }
            }
            else {
                if (abs(target.y - pBack->val.y) < dist) {//Y方向相交。
                    KdTree* newTree{};
                    if ((target.y > pBack->val.y) && (pBack->left != NULL)) {//点在右侧,向左搜索。
                        search_path.push(pBack->left);
                        newTree = pBack->left;
                    }
                    if ((target.y < pBack->val.y) && (pBack->right != NULL)) {
                        search_path.push(pBack->right);
                        newTree = pBack->right;
                    };
                    //搜索新发现的路径
                    buildpath(target, newTree);
                }
            }
        }

    }
    return result{ nearest ,dist };
}
    
    

//打印树结构
void printNode(Node node) {
    cout << "("<val);
    if (root->left != NULL) {
        printTree_rootfirst(root->left);
    }
    if (root->right != NULL) {
        printTree_rootfirst(root->right);
    }
}

void printTree_leftfirst(KdTree* root) {
    if (root->left != NULL) {
        printTree_leftfirst(root->left);
    }
    printNode(root->val);
    if (root->right != NULL) {
        printTree_leftfirst(root->right);
    }
}
int main() {
    KdTree * root = NULL;
    root = build(N, dataSet, root); 
    Node target = {2,4.5};
    result res = findnearest(target,root);
    cout <<"最近距离:"<< res.dist << endl;
    cout <<"X方向:"<< res.resNode.x << endl;
    cout << "Y方向:" << res.resNode.y << endl;
    system("pause");
}

你可能感兴趣的:(KdTree 最近邻查找算法(C++描述))