使用KD-Tree树查找最近邻点 - 二维

文章目录

  • KD-Tree
    • 介绍
    • 用法
    • KD-Tree的构建
    • 最邻近点的查找

KD-Tree

介绍

kd-tree,是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。 主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索)。K-D树是二进制空间分割树的特殊的情况。

用法

QT图形视图框架绘制曲线图和Smith图_qtcreator画smith图-CSDN博客 在这篇博客中绘图的mark点支持,最初我是使用set的自排序功能来查找最近的点,虽然效率能跟得上使用,但是这种方式弊端非常明显。每次移动的时候就需要重写计算距离,然后重新生成这个set。这个方式浪费不仅浪费空间也浪费时间。

所以在后续更新的时候使用了KD-Tree做了最邻近点的计算,比上述暴力的的排序方式快了不止30倍(这里是只测试了4W个点,使用排序打印的ms为0,使用上面的方式打印的在30左右)。

接下来展示我的代码:

kdtree.h

#ifndef TURBOPLOT_KDTREE_H
#define TURBOPLOT_KDTREE_H

/******************************************************************
 * @brief     求最邻近点算法
 * @author    turbo
 * @class     KDTree
 * @date      2023年11月2日10:59:27
 * @version   0.0.1
 * @property  root_            kdtree跟节点
 ******************************************************************/

#include 
#include 

namespace turbo
{
    class KDTree
    {
    public:
        struct KDTreeNode
        {
            QPointF      point;
            KDTreeNode*  left;
            KDTreeNode*  right;
        };
        KDTree();
        ~KDTree();
        /**
         * @brief 设置点,根据点生成对应的KDTree
         * @param points
         */
        void setPoints(QVector<QPointF> &points);
        /**
         * @brief 获取离入参点最近的点
         * @param point
         * @return
         */
        QPointF getNearestPoint(const QPointF &point);

    protected:
        /**
         * @brief 计算两个点之间的欧氏距离
         * @param point1
         * @param point2
         * @return
         */
        static double euclideanDistance(const QPointF& point1, const QPointF& point2);
        /**
         * @brief 根据两个点的x比较大小
         * @param point1
         * @param point2
         * @return
         */
        static bool compareX(const QPointF& point1, const QPointF& point2);
        /**
         * @brief 根据两个点的y比较大小
         * @param point1
         * @param point2
         * @return
         */
        static bool compareY(const QPointF& point1, const QPointF& point2);
        /**
         * @brief 构建KD树
         * @param points
         * @param compareX
         * @return
         */
        KDTreeNode* buildKdTree(QVector<QPointF>& points, int depth);
        /**
         * @brief 计算最邻近点
         * @param root
         * @param target
         * @param nearest
         * @param minDist
         * @param depth
         */
        void findNearestNeighbor(KDTreeNode* root, QPointF target, QPointF& nearest, double& minDist, int depth);
        /**
         * @brief 清理kdtree 释放空间
         * @param root
         */
        void clearKdTree(KDTreeNode* root);

    private:
        KDTreeNode *root_;
    };
}

#endif //TURBOPLOT_KDTREE_H

kdtree.cpp

#include "kdtree.h"
#include 
#include 

namespace turbo
{
    KDTree::KDTree() : root_(nullptr)
    {

    }

    KDTree::~KDTree()
    {
        clearKdTree(root_);
    }

    // 清理KD树
    void KDTree::clearKdTree(KDTreeNode* root)
    {
        if (root == nullptr)
        {
            return;
        }
        clearKdTree(root->left);
        clearKdTree(root->right);
        delete root;
        root = nullptr;
    }

    double KDTree::euclideanDistance(const QPointF &point1, const QPointF &point2)
    {
        double dx = point1.x() - point2.x();
        double dy = point1.y() - point2.y();
        return std::sqrt(dx*dx + dy*dy);
    }

    bool KDTree::compareX(const QPointF &point1, const QPointF &point2)
    {
        return point1.x() < point2.x();
    }

    bool KDTree::compareY(const QPointF &point1, const QPointF &point2)
    {
        return point1.y() < point2.y();
    }

    KDTree::KDTreeNode *KDTree::buildKdTree(QVector<QPointF> &points, int depth)
    {
        if (points.isEmpty())
        {
            return nullptr;
        }
        // 根据分割轴对点进行排序
        if (depth % 2 == 0)
        {
            std::sort(points.begin(), points.end(), compareX);
        }
        else
        {
            std::sort(points.begin(), points.end(), compareY);
        }

        // 选择中间点作为根节点
        int mid = points.size() / 2;
        auto* root = new KDTreeNode();
        root->point = points[mid];
        QVector<QPointF> points1 = QVector<QPointF>(points.begin(), points.begin() + mid);
        QVector<QPointF> points2 = QVector<QPointF>(points.begin() + mid + 1, points.end());
        root->left = buildKdTree(points1, depth + 1);
        root->right = buildKdTree(points2, depth + 1);
        return root;
    }

    // 在KD树中查找最近邻点
    void KDTree::findNearestNeighbor(KDTree::KDTreeNode* root, QPointF target, QPointF& nearest, double& minDist, int depth)
    {
        if (root == nullptr)
        {
            return;
        }
        // 计算当前节点到目标点的欧氏距离
        double dist = euclideanDistance(root->point, target);
        // 如果当前节点更近,则更新最近邻点和最小距离
        if (dist < minDist)
        {
            nearest = root->point;
            minDist = dist;
        }
        // 根据深度选择分割轴
        int axis = depth % 2;
        // 根据分割轴比较目标点和当前节点,并决定遍历顺序
        if ((axis == 0 && target.x() < root->point.x()) || (axis == 1 && target.y() < root->point.y()))
        {
            findNearestNeighbor(root->left, target, nearest, minDist, depth + 1);
            if ((axis == 0 && target.x() + minDist >= root->point.x()) || (axis == 1 && target.y() + minDist >= root->point.y()))
            {
                findNearestNeighbor(root->right, target, nearest, minDist, depth + 1);
            }
        }
        else
        {
            findNearestNeighbor(root->right, target, nearest, minDist, depth + 1);
            if ((axis == 0 && target.x() - minDist <= root->point.x()) || axis == 1 && target.y() - minDist <= root->point.y())
            {
                findNearestNeighbor(root->left, target, nearest, minDist, depth + 1);
            }
        }
    }

    void KDTree::setPoints(QVector<QPointF> &points)
    {
        clearKdTree(root_);
        root_ = buildKdTree(points, 0);
    }

    QPointF KDTree::getNearestPoint(const QPointF &point)
    {
        QPointF nearest;
        double minDist = std::numeric_limits<double>::max();
        // 在KD树中查找最近邻点
        findNearestNeighbor(root_, point, nearest, minDist, 0);
        return nearest;
    }
}

这里面就说一下构建时候的思路和查找时候的思路

KD-Tree的构建

KDTree::KDTreeNode *KDTree::buildKdTree(QVector<QPointF> &points, int depth)
{
    if (points.isEmpty())
    {
        return nullptr;
    }
    // 根据分割轴对点进行排序
    if (depth % 2 == 0)
    {
        std::sort(points.begin(), points.end(), compareX);
    }
    else
    {
        std::sort(points.begin(), points.end(), compareY);
    }

    // 选择中间点作为根节点
    int mid = points.size() / 2;
    auto* root = new KDTreeNode();
    root->point = points[mid];
    QVector<QPointF> points1 = QVector<QPointF>(points.begin(), points.begin() + mid);
    QVector<QPointF> points2 = QVector<QPointF>(points.begin() + mid + 1, points.end());
    root->left = buildKdTree(points1, depth + 1);
    root->right = buildKdTree(points2, depth + 1);
    return root;
}
  1. 检查点集是否为空,如果是则返回空指针。
  2. 根据当前深度选择分割轴,偶数深度使用compareX函数排序,奇数深度使用compareY函数排序。
  3. 选择中间点作为根节点,并创建一个新的KDTreeNode对象来存储该点。
  4. 将点集分为两部分,分别是从起始位置到中间位置的点集points1和从中间位置+1到结束位置的点集points2。
  5. 递归地调用buildKdTree函数来构建左子树,并将其赋值给根节点的left指针。
  6. 递归地调用buildKdTree函数来构建右子树,并将其赋值给根节点的right指针。

这么做的主要目的就是将一整个区域分成不同的区域

最邻近点的查找

// 在KD树中查找最近邻点
void KDTree::findNearestNeighbor(KDTree::KDTreeNode* root, QPointF target, QPointF& nearest, double& minDist, int depth)
{
    if (root == nullptr)
    {
        return;
    }
    // 计算当前节点到目标点的欧氏距离
    double dist = euclideanDistance(root->point, target);
    // 如果当前节点更近,则更新最近邻点和最小距离
    if (dist < minDist)
    {
        nearest = root->point;
        minDist = dist;
    }
    // 根据深度选择分割轴
    int axis = depth % 2;
    // 根据分割轴比较目标点和当前节点,并决定遍历顺序
    if ((axis == 0 && target.x() < root->point.x()) || (axis == 1 && target.y() < root->point.y()))
    {
        findNearestNeighbor(root->left, target, nearest, minDist, depth + 1);
        if ((axis == 0 && target.x() + minDist >= root->point.x()) || (axis == 1 && target.y() + minDist >= root->point.y()))
        {
            findNearestNeighbor(root->right, target, nearest, minDist, depth + 1);
        }
    }
    else
    {
        findNearestNeighbor(root->right, target, nearest, minDist, depth + 1);
        if ((axis == 0 && target.x() - minDist <= root->point.x()) || axis == 1 && target.y() - minDist <= root->point.y())
        {
            findNearestNeighbor(root->left, target, nearest, minDist, depth + 1);
        }
    }
}

这里注释都写的很清楚了,本质上就是递归查找。这边要说明白得绘图了,单纯用文字描述的话,确实有点不太清楚。但是我绘图技术太菜,直接看代码理解吧。

判断的地方,因为之前划分的时候就是按照深度划分的,这里根据深度来决定是使用x还是y来找最近点。

你可能感兴趣的:(数据结构和算法,数据结构)