k-d tree代码解析

http://www.cnblogs.com/eyeszjwang/articles/2432465.html 

 上一篇较详细地介绍了k-d树算法。本文来讲解具体的实现代码。

  首先是一些数据结构的定义。我们先来定义单个数据,代码如下:
复制代码

//单个数据向量结构定义
struct _Examplar
{
public:

    _Examplar():dom_dims(0){}        //数据维度初始化为0
  //带有完整的两个参数的constructor,这里const是为了保护原数据不被修改
    _Examplar(const std::vector<double> elt, int dims)   
    {                                                    
        if(dims > 0)
        {
            dom_elt = elt;
            dom_dims = dims;
        }
        else
        {
            dom_dims = 0;
        }
    }

(一些重载的构造函数和运算符,元素的访问控制函数等)
复制代码

        _Examplar(int dims)    //只含有维度信息的constructor
    {
        if(dims > 0)
        {
            dom_elt.resize(dims);
            dom_dims = dims;
        }
        else
        {
            dom_dims = 0;
        }
    }
    _Examplar(const _Examplar& rhs)    //copy-constructor
    {
        if(rhs.dom_dims > 0)
        {
            dom_elt = rhs.dom_elt;
            dom_dims = rhs.dom_dims;
        }
        else
        {
            dom_dims = 0;
        }
    }
    _Examplar& operator=(const _Examplar& rhs)    //重载"="运算符
    {
        if(this == &rhs) 
            return *this;

        releaseExamplarMem();

        if(rhs.dom_dims > 0)
        {
            dom_elt = rhs.dom_elt;
            dom_dims = rhs.dom_dims;
        }

        return *this;
    }
    ~_Examplar()
    {
    }
    double& dataAt(int dim)    //定义访问控制函数
    {
        assert(dim < dom_dims);
        return dom_elt[dim];
    }
    double& operator[](int dim)    //重载"[]"运算符,实现下标访问
    {
        return dataAt(dim);
    }
    const double& dataAt(int dim) const    //定义只读访问函数
    {
        assert(dim < dom_dims);
        return dom_elt[dim];
    }
    const double& operator[](int dim) const    //重载"[]"运算符,实现下标只读访问
    {
        return dataAt(dim);
    }
    void create(int dims)    //创建数据向量
    {
        releaseExamplarMem();
        if(dims > 0)
        {
            dom_elt.resize(dims);    //控制数据向量维度
            dom_dims = dims;
        }
    }
    int getDomDims() const     //获得数据向量维度信息
    {
        return dom_dims;
    }
    void setTo(double val)    //数据向量初始化设置
    {
        if(dom_dims > 0)
        {
            for(int i=0;i<dom_dims;i++)
            {
                dom_elt[i] = val;
            }
        }
    }
private:
    void releaseExamplarMem()    //清除现有数据向量
    {
        dom_elt.clear();
        dom_dims = 0;
    }

复制代码

private:
    std::vector<double> dom_elt;    //每个数据定义为一个double类型的向量
    int dom_dims;                   //数据向量的维度
};

复制代码

  结构_Examplar定义了单个数据节点的结构,主要包含的信息有:1.数据向量本身;2.数据向量的维度。接下来定义一整个数据集的结构,代码如下:
复制代码

//数据集结构定义
class ExamplarSet : public TrainData      //整个数据集类,由一个抽象类TrainData派生
{
private:
    //_Examplar *_ex_set;
    std::vector<_Examplar> _ex_set;       //定义含有若干个_Examplar类数据向量的数据集
    int _size;                            //数据集大小
    int _dims;                            //数据集中每个数据向量的维度
public:

(一些重载的构造函数运算符,元素访问控制函数等)
复制代码

        ExamplarSet():_size(0), _dims(0){}
    ExamplarSet(std::vector<_Examplar> ex_set, int size, int dims);
    ExamplarSet(int size, int dims);
    ExamplarSet(const ExamplarSet& rhs);
    ExamplarSet& operator=(const ExamplarSet& rhs);
    ~ExamplarSet(){}

    _Examplar& examplarAt(int idx)
    { 
        assert(idx < _size);
        return _ex_set[idx]; 
    }
    _Examplar& operator[](int idx)
    {
        return examplarAt(idx);
    }
    const _Examplar& examplarAt(int idx) const
    {
        assert(idx < _size);
        return _ex_set[idx];
    }
    void create(int size, int dims);
    int getDims() const { return _dims;}
    int getSize() const { return _size;}
    _HyperRectangle calculateRange();
    bool empty() const
    {
        return (_size == 0);
    }

复制代码

    void sortByDim(int dim);     //按某个方向维的排序函数
    bool remove(int idx);        //去除数据集中排序后指定位置的数据向量
    void push_back(const _Examplar& ex)    //添加某个数据向量至数据集末尾
    {
        _ex_set.push_back(ex);
        _size++;
    }
    int readData(char *strFilePath);    //从文件读取数据集
private:
    void releaseExamplarSetMem()        //清除现有数据集
    {
        _ex_set.clear();
        _size = 0;
    }
};

复制代码

  类ExamplarSet定义了整个数据集的结构,其包含的主要信息有:1.含有若干个_Examplar类数据向量的数据集;2.数据集的大小;3.每个数据向量的维度。以上两个结构是整个算法两个基本的数据结构,这里的代码只是展示其主要包含的结构信息,详细的定义及函数实现代码请参看附件。

  接下来就要定义k-d tree的结构。同样采用上述由点定义到集定义的思路,我们先来定义k-d tree中一个节点结构,代码如下:
复制代码

//k-d tree节点结构定义
class KDTreeNode    
{
private:
    int _split_dim;                //该节点的最大区分度方向维
    _Examplar _dom_elt;            //该节点的数据向量
    _HyperRectangle _range_hr;     //表示数据范围的超矩形结构
public:
    KDTreeNode *_left_child, *_right_child, *_parent;        //该节点的左右子树和父节点

(一些重载的构造函数,元素访问控制函数等)
复制代码

public:
    KDTreeNode():_left_child(0), _right_child(0), _parent(0), 
        _split_dim(0){}
    KDTreeNode(KDTreeNode *left_child, KDTreeNode *right_child, 
        KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr):
    _left_child(left_child), _right_child(right_child), _parent(parent),
        _split_dim(split_dim), _dom_elt(dom_elt), _range_hr(range_hr){}
    KDTreeNode(const KDTreeNode &rhs);
    KDTreeNode& operator=(const KDTreeNode &rhs);
    _Examplar& getDomElt() { return _dom_elt; }
    _HyperRectangle& getHyperRectangle(){ return _range_hr; }
    int& splitDim(){ return _split_dim; }
    void create(KDTreeNode *left_child, KDTreeNode *right_child, 
        KDTreeNode *parent, int split_dim, _Examplar dom_elt,  _HyperRectangle range_hr);

复制代码

};

复制代码

  类KDTreeNode就是按照前一篇表1所述定义的。需要注意的是_HyperRectangle这一结构,它表示的就是这一节点所代表的空间范围Range,其定义如下:
复制代码

struct _HyperRectangle    //定义表示数据范围的超矩形结构
{
    _Examplar min;        //统计数据集中所有数据向量每个维度上最小值组成的一个数据向量
    _Examplar max;        //统计数据集中所有数据向量每个维度上最大值组成的一个数据向量

(一些重载的构造函数)
复制代码

        _HyperRectangle() {}
    _HyperRectangle(_Examplar mx, _Examplar mn)
    {
        assert (mx.getDomDims() == mn.getDomDims());
        min = mn;
        max = mx;
    }
    _HyperRectangle(const _HyperRectangle& rhs)
    {
        min = rhs.min;
        max = rhs.max;
    }
    _HyperRectangle& operator= (const _HyperRectangle& rhs)
    {
        if(this == &rhs)
            return *this;
        min = rhs.min;
        max = rhs.max;
        return *this;
    }
    void create(_Examplar mx, _Examplar mn)
    {
        assert (mx.getDomDims() == mn.getDomDims());
        min = mn;
        max = mx;
    }

复制代码

};

复制代码

  对于整个数据集来说_HyperRectangle表示的就是对全体的统计范围信息,对部分数据集来说其表示的就是对部分数据的统计范围信息。还是以上篇中实例中的数据{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}为例,_HyperRectangle表示的统计范围如图1所示:

图1  _HyperRectangle表示的统计范围

    对于根节点(7,2),其所对应的空间范围是整个数据集,所以根节点(7,2)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的数据范围统计得min = {dom_elt = (2,1),dom_dims = 2},max = {dom_elt = (9,7),dom_dims = 2};
    对于中间节点(5,4),其所对应的空间范围是根节点的左子树,所以节点(5,4)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的数据范围统计得min = {dom_elt = (2,3),dom_dims = 2},max = {dom_elt = (5,7),dom_dims = 2};
    对于叶子节点(4,7),其所对应的空间范围是节点本身,所以节点(4,7)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的 数据范围统计得min = {dom_elt = (4,7),dom_dims = 2},max = {dom_elt = (4,7),dom_dims = 2};

  最后再进行整个k-d tree结构的定义。代码如下:
复制代码

class KDTree    //k-d tree结构定义
{
public:
    KDTreeNode *_root;        //k-d tree的根节点
public:
    KDTree():_root(NULL){}
    void create(const ExamplarSet &exm_set);        //创建k-d tree,实际上调用createKDTree
    void destroy();                                 //销毁k-d tree,实际上调用destroyKDTree
    ~KDTree(){ destroyKDTree(_root); }
    std::pair<_Examplar, double> findNearest(_Examplar target);    //查找最近邻点函数,返回值是pair类型
                                                                   //实际是调用findNearest_i
    //查找距离在range范围内的近邻点,返回这样近邻点的个数,实际是调用findNearest_range
    int findNearest(_Examplar target, double range, std::vector<std::pair<_Examplar, double>> &res_nearest);
private:
    KDTreeNode* createKDTree(const ExamplarSet &exm_set);
    void destroyKDTree(KDTreeNode *root);
    std::pair<_Examplar, double> findNearest_i(KDTreeNode *root, _Examplar target);
    int findNearest_range(KDTreeNode *root, _Examplar target, double range, 
        std::vector<std::pair<_Examplar, double>> &res_nearest);

复制代码

  可见,整个k-d tree结构是由一系列KDTreeNode类的节点构成。整个k-d树的构建算法和基于k-d树的最邻近查找算法主要就是由createKDTree,findNearest_i以及findNearest_range这三个函数完成。代码分别如下:

    createKDTree

复制代码

//KDTree::是由于定义了KDTree的namespace
KDTree::KDTreeNode* KDTree::KDTree::createKDTree( const ExamplarSet &exm_set )
{
    if(exm_set.empty())
        return NULL;

    ExamplarSet exm_set_copy(exm_set);

    int dims = exm_set_copy.getDims();
    int size = exm_set_copy.getSize();

    //计算每个维的方差,选出方差值最大的维
    double var_max = -0.1; 
    double avg, var;
    int dim_max_var = -1;
    for(int i=0;i<dims;i++)
    {
        avg = 0;
        var = 0;
        //求某一维的总和
        for(int j=0;j<size;j++)
        {
            avg += exm_set_copy[j][i];
        }
        //求平均
        avg /= size;
        //求方差
        for(int j=0;j<size;j++)
        {
            var += ( exm_set_copy[j][i] - avg ) * 
                ( exm_set_copy[j][i] - avg );
        }
        var /= size;
        if(var > var_max)
        {
            var_max = var;
            dim_max_var = i;
        }
    }

    //确定节点的数据矢量
    _HyperRectangle hr = exm_set_copy.calculateRange();     //统计节点空间范围
    exm_set_copy.sortByDim(dim_max_var);                    //将所有数据向量按最大区分度方向排序
    int mid = size / 2;
    _Examplar exm_split = exm_set_copy.examplarAt(mid);     //取出排序结果的中间节点
    exm_set_copy.remove(mid);                               //将中间节点作为父(根)节点,所有将其从数据集中去除

    //确定左右节点
    ExamplarSet exm_set_left(0, exm_set_copy.getDims());
    ExamplarSet exm_set_right(0, exm_set_copy.getDims());
    exm_set_right.remove(0);

    int size_new = exm_set_copy.getSize();         //获得子数据空间大小
    for(int i=0;i<size_new;i++)                    //生成左右子节点
    {
        _Examplar temp = exm_set_copy[i];
        if( temp.dataAt(dim_max_var) < 
            exm_split.dataAt(dim_max_var) )
            exm_set_left.push_back(temp);
        else
            exm_set_right.push_back(temp);
    }

    KDTreeNode *pNewNode = new KDTreeNode(0, 0, 0, dim_max_var, exm_split, hr);
    pNewNode->_left_child = createKDTree(exm_set_left);        //递归调用生成左子树
    if(pNewNode->_left_child != NULL)                          //确认左子树父节点
        pNewNode->_left_child->_parent = pNewNode;            
    pNewNode->_right_child = createKDTree(exm_set_right);      //递归调用生成右子树
    if(pNewNode->_right_child != NULL)                         //确认右子树父节点
        pNewNode->_right_child->_parent = pNewNode;

    return pNewNode;    //最终返回k-d tree的根节点
}

复制代码

  整个createKDTree函数完全符合上篇中表2所述。注意其中统计节点空间范围calculateRange这一函数,其定义如下:
复制代码

KDTree::_HyperRectangle KDTree::ExamplarSet::calculateRange()
{
    assert(_size > 0);
    assert(_dims > 0);
    _Examplar mn(_dims);
    _Examplar mx(_dims);

    for(int j=0;j<_dims;j++)
    {
        mn.dataAt(j) = (*this)[0][j];    //初始化最小范围向量
        mx.dataAt(j) = (*this)[0][j];    //初始化最大范围向量
    }

    for(int i=1;i<_size;i++)            //统计数据集中每一个数据向量
    {
        for(int j=0;j<_dims;j++)
        {
            if( (*this)[i][j] < mn[j] )    //比较每一维,寻找最小值
                mn[j] = (*this)[i][j];
            if( (*this)[i][j] > mx[j] )    //比较每一维,寻找最大值
                mx[j] = (*this)[i][j];
        }
    }
    _HyperRectangle hr(mx, mn);

    return hr;    //返回一个_HyperRectangle结构
}

复制代码

    findNearest_i

复制代码

std::pair<KDTree::_Examplar, double> KDTree::KDTree::findNearest_i( KDTreeNode *root, _Examplar target )
{
    KDTreeNode *pSearch = root;

    //堆栈用于保存搜索路径
    std::vector<KDTreeNode*> search_path;

    _Examplar nearest;

    double max_dist;

    while(pSearch != NULL)                //首先通过二叉查找得到搜索路径
    {
        search_path.push_back(pSearch);
        int s = pSearch->splitDim();
        if(target[s] <= pSearch->getDomElt()[s])
        {
            pSearch = pSearch->_left_child;
        }
        else
        {
            pSearch = pSearch->_right_child;
        }
    }

    nearest = search_path.back()->getDomElt();        //取路径中最后的叶子节点为回溯前的最邻近点
    max_dist = Distance_exm(nearest, target);

    search_path.pop_back();

    //回溯搜索路径
    while(!search_path.empty())
    {
        KDTreeNode *pBack = search_path.back();
        search_path.pop_back();

        if( pBack->_left_child == NULL && pBack->_right_child == NULL)        //如果是叶子节点,就直接比较距离的大小
        {
            if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) )
            {
                nearest = pBack->getDomElt();
                max_dist = Distance_exm(pBack->getDomElt(), target);
            }
        }
        else
        {
            int s = pBack->splitDim();
            if( abs(pBack->getDomElt()[s] - target[s]) < max_dist)    //以target为圆心,max_dist为半径的圆和分割面如果
            {                                                         //有交割,则需要进入另一边子空间搜索
                if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) )
                {
                    nearest = pBack->getDomElt();
                    max_dist = Distance_exm(pBack->getDomElt(), target);
                }
                if(target[s] <= pBack->getDomElt()[s])    //如果target位于左子空间,就应进入右子空间
                    pSearch = pBack->_right_child;
                else
                    pSearch = pBack->_left_child;         //如果target位于右子空间,就应进入左子空间
                if(pSearch != NULL)
                    search_path.push_back(pSearch);       //将新的节点加入search_path中
            }
        }
    }

    std::pair<_Examplar, double> res(nearest, max_dist);

    return res;        //返回包含最邻近点和最近距离的pair
}

复制代码

    findNearest_range

复制代码

int KDTree::KDTree::findNearest_range( KDTreeNode *root, _Examplar target, double range, 
    std::vector<std::pair<_Examplar, double>> &res_nearest )
{
    if(root == NULL)
        return 0;
    double dist_sq, dx;
    int ret, added_res = 0;
    dist_sq = 0;
    dist_sq = Distance_exm(root->getDomElt(), target);    //计算搜索路径中每个节点和target的距离

    if(dist_sq <= range) {                   //将范围内的近邻添加到结果向量res_nearest中
        std::pair<_Examplar,double> temp(root->getDomElt(), dist_sq);
        res_nearest.push_back(temp);
        //结果个数+1
        added_res = 1;
    }

    dx = target[root->splitDim()] - root->getDomElt()[root->splitDim()];
    //左子树或右子树递归的查找
    ret = findNearest_range(dx <= 0.0 ? root->_left_child : root->_right_child, target, range, res_nearest);
    //当另外一边可能存在范围内的近邻
    if(ret >= 0 && fabs(dx) < range) {
        added_res += ret;
        ret = findNearest_range(dx <= 0.0 ? root->_right_child : root->_left_child, target, range, res_nearest);
    }

    added_res += ret;
    return added_res;        //最终返回范围内的近邻个数
}

你可能感兴趣的:(k-d tree代码解析)