上一篇较详细地介绍了k-d树算法。本文来讲解具体的实现代码。
首先是一些数据结构的定义。我们先来定义单个数据,代码如下:
//单个数据向量结构定义 struct _Examplar { public: _Examplar():dom_dims(0){} //数据维度初始化为0 //带有完整的两个参数的constructor,这里const是为了保护原数据不被修改 _Examplar(const std::vector<double> elt, int dims)
{
if(dims > 0) { dom_elt = elt; dom_dims = dims; } else { dom_dims = 0; } }
_Examplar(int dims) //只含有维度信息的constructor { if(dims > 0) { dom_elt.resize(dims); dom_dims = dims; } else { dom_dims = 0; } } _Examplar(const _Examplar& rhs) //copy-constructor { if(rhs.dom_dims > 0) { dom_elt = rhs.dom_elt; dom_dims = rhs.dom_dims; } else { dom_dims = 0; } } _Examplar& operator=(const _Examplar& rhs) //重载"="运算符 { if(this == &rhs) return *this; releaseExamplarMem(); if(rhs.dom_dims > 0) { dom_elt = rhs.dom_elt; dom_dims = rhs.dom_dims; } return *this; } ~_Examplar() { } double& dataAt(int dim) //定义访问控制函数 { assert(dim < dom_dims); return dom_elt[dim]; } double& operator[](int dim) //重载"[]"运算符,实现下标访问 { return dataAt(dim); } const double& dataAt(int dim) const //定义只读访问函数 { assert(dim < dom_dims); return dom_elt[dim]; } const double& operator[](int dim) const //重载"[]"运算符,实现下标只读访问 { return dataAt(dim); } void create(int dims) //创建数据向量 { releaseExamplarMem(); if(dims > 0) { dom_elt.resize(dims); //控制数据向量维度 dom_dims = dims; } } int getDomDims() const //获得数据向量维度信息 { return dom_dims; } void setTo(double val) //数据向量初始化设置 { if(dom_dims > 0) { for(int i=0;i<dom_dims;i++) { dom_elt[i] = val; } } } private: void releaseExamplarMem() //清除现有数据向量 { dom_elt.clear(); dom_dims = 0; }
private: std::vector<double> dom_elt; //每个数据定义为一个double类型的向量 int dom_dims; //数据向量的维度 };
结构_Examplar定义了单个数据节点的结构,主要包含的信息有:1.数据向量本身;2.数据向量的维度。接下来定义一整个数据集的结构,代码如下:
//数据集结构定义 class ExamplarSet : public TrainData //整个数据集类,由一个抽象类TrainData派生 { private: //_Examplar *_ex_set; std::vector<_Examplar> _ex_set; //定义含有若干个_Examplar类数据向量的数据集 int _size; //数据集大小 int _dims; //数据集中每个数据向量的维度 public:
ExamplarSet():_size(0), _dims(0){} ExamplarSet(std::vector<_Examplar> ex_set, int size, int dims); ExamplarSet(int size, int dims); ExamplarSet(const ExamplarSet& rhs); ExamplarSet& operator=(const ExamplarSet& rhs); ~ExamplarSet(){} _Examplar& examplarAt(int idx) { assert(idx < _size); return _ex_set[idx]; } _Examplar& operator[](int idx) { return examplarAt(idx); } const _Examplar& examplarAt(int idx) const { assert(idx < _size); return _ex_set[idx]; } void create(int size, int dims); int getDims() const { return _dims;} int getSize() const { return _size;} _HyperRectangle calculateRange(); bool empty() const { return (_size == 0); }
void sortByDim(int dim); //按某个方向维的排序函数 bool remove(int idx); //去除数据集中排序后指定位置的数据向量 void push_back(const _Examplar& ex) //添加某个数据向量至数据集末尾 { _ex_set.push_back(ex); _size++; } int readData(char *strFilePath); //从文件读取数据集 private: void releaseExamplarSetMem() //清除现有数据集 { _ex_set.clear(); _size = 0; } };
类ExamplarSet定义了整个数据集的结构,其包含的主要信息有:1.含有若干个_Examplar类数据向量的数据集;2.数据集的大小;3.每个数据向量的维度。以上两个结构是整个算法两个基本的数据结构,这里的代码只是展示其主要包含的结构信息,详细的定义及函数实现代码请参看附件。
接下来就要定义k-d tree的结构。同样采用上述由点定义到集定义的思路,我们先来定义k-d tree中一个节点结构,代码如下:
//k-d tree节点结构定义 class KDTreeNode { private: int _split_dim; //该节点的最大区分度方向维 _Examplar _dom_elt; //该节点的数据向量 _HyperRectangle _range_hr; //表示数据范围的超矩形结构 public: KDTreeNode *_left_child, *_right_child, *_parent; //该节点的左右子树和父节点
public: KDTreeNode():_left_child(0), _right_child(0), _parent(0), _split_dim(0){} KDTreeNode(KDTreeNode *left_child, KDTreeNode *right_child, KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr): _left_child(left_child), _right_child(right_child), _parent(parent), _split_dim(split_dim), _dom_elt(dom_elt), _range_hr(range_hr){} KDTreeNode(const KDTreeNode &rhs); KDTreeNode& operator=(const KDTreeNode &rhs); _Examplar& getDomElt() { return _dom_elt; } _HyperRectangle& getHyperRectangle(){ return _range_hr; } int& splitDim(){ return _split_dim; } void create(KDTreeNode *left_child, KDTreeNode *right_child, KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr);
};
类KDTreeNode就是按照前一篇表1所述定义的。需要注意的是_HyperRectangle这一结构,它表示的就是这一节点所代表的空间范围Range,其定义如下:
struct _HyperRectangle //定义表示数据范围的超矩形结构 { _Examplar min; //统计数据集中所有数据向量每个维度上最小值组成的一个数据向量 _Examplar max; //统计数据集中所有数据向量每个维度上最大值组成的一个数据向量
_HyperRectangle() {} _HyperRectangle(_Examplar mx, _Examplar mn) { assert (mx.getDomDims() == mn.getDomDims()); min = mn; max = mx; } _HyperRectangle(const _HyperRectangle& rhs) { min = rhs.min; max = rhs.max; } _HyperRectangle& operator= (const _HyperRectangle& rhs) { if(this == &rhs) return *this; min = rhs.min; max = rhs.max; return *this; } void create(_Examplar mx, _Examplar mn) { assert (mx.getDomDims() == mn.getDomDims()); min = mn; max = mx; }
};
对于整个数据集来说_HyperRectangle表示的就是对全体的统计范围信息,对部分数据集来说其表示的就是对部分数据的统计范围信息。还是以上篇中实例中的数据{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}为例,_HyperRectangle表示的统计范围如图1所示:
图1 _HyperRectangle表示的统计范围
最后再进行整个k-d tree结构的定义。代码如下:
class KDTree //k-d tree结构定义 { public: KDTreeNode *_root; //k-d tree的根节点 public: KDTree():_root(NULL){} void create(const ExamplarSet &exm_set); //创建k-d tree,实际上调用createKDTree void destroy(); //销毁k-d tree,实际上调用destroyKDTree ~KDTree(){ destroyKDTree(_root); } std::pair<_Examplar, double> findNearest(_Examplar target); //查找最近邻点函数,返回值是pair类型 //实际是调用findNearest_i //查找距离在range范围内的近邻点,返回这样近邻点的个数,实际是调用findNearest_range int findNearest(_Examplar target, double range, std::vector<std::pair<_Examplar, double>> &res_nearest); private: KDTreeNode* createKDTree(const ExamplarSet &exm_set); void destroyKDTree(KDTreeNode *root); std::pair<_Examplar, double> findNearest_i(KDTreeNode *root, _Examplar target); int findNearest_range(KDTreeNode *root, _Examplar target, double range, std::vector<std::pair<_Examplar, double>> &res_nearest);
可见,整个k-d tree结构是由一系列KDTreeNode类的节点构成。整个k-d树的构建算法和基于k-d树的最邻近查找算法主要就是由createKDTree,findNearest_i以及findNearest_range这三个函数完成。代码分别如下:
//KDTree::是由于定义了KDTree的namespace KDTree::KDTreeNode* KDTree::KDTree::createKDTree( const ExamplarSet &exm_set ) { if(exm_set.empty()) return NULL; ExamplarSet exm_set_copy(exm_set); int dims = exm_set_copy.getDims(); int size = exm_set_copy.getSize(); //计算每个维的方差,选出方差值最大的维 double var_max = -0.1; double avg, var; int dim_max_var = -1; for(int i=0;i<dims;i++) { avg = 0; var = 0; //求某一维的总和 for(int j=0;j<size;j++) { avg += exm_set_copy[j][i]; } //求平均 avg /= size; //求方差 for(int j=0;j<size;j++) { var += ( exm_set_copy[j][i] - avg ) * ( exm_set_copy[j][i] - avg ); } var /= size; if(var > var_max) { var_max = var; dim_max_var = i; } } //确定节点的数据矢量 _HyperRectangle hr = exm_set_copy.calculateRange(); //统计节点空间范围 exm_set_copy.sortByDim(dim_max_var); //将所有数据向量按最大区分度方向排序 int mid = size / 2; _Examplar exm_split = exm_set_copy.examplarAt(mid); //取出排序结果的中间节点 exm_set_copy.remove(mid); //将中间节点作为父(根)节点,所有将其从数据集中去除 //确定左右节点 ExamplarSet exm_set_left(0, exm_set_copy.getDims()); ExamplarSet exm_set_right(0, exm_set_copy.getDims()); exm_set_right.remove(0); int size_new = exm_set_copy.getSize(); //获得子数据空间大小 for(int i=0;i<size_new;i++) //生成左右子节点 { _Examplar temp = exm_set_copy[i]; if( temp.dataAt(dim_max_var) < exm_split.dataAt(dim_max_var) ) exm_set_left.push_back(temp); else exm_set_right.push_back(temp); } KDTreeNode *pNewNode = new KDTreeNode(0, 0, 0, dim_max_var, exm_split, hr); pNewNode->_left_child = createKDTree(exm_set_left); //递归调用生成左子树 if(pNewNode->_left_child != NULL) //确认左子树父节点 pNewNode->_left_child->_parent = pNewNode; pNewNode->_right_child = createKDTree(exm_set_right); //递归调用生成右子树 if(pNewNode->_right_child != NULL) //确认右子树父节点 pNewNode->_right_child->_parent = pNewNode; return pNewNode; //最终返回k-d tree的根节点 }
整个createKDTree函数完全符合上篇中表2所述。注意其中统计节点空间范围calculateRange这一函数,其定义如下:
KDTree::_HyperRectangle KDTree::ExamplarSet::calculateRange() { assert(_size > 0); assert(_dims > 0); _Examplar mn(_dims); _Examplar mx(_dims); for(int j=0;j<_dims;j++) { mn.dataAt(j) = (*this)[0][j]; //初始化最小范围向量 mx.dataAt(j) = (*this)[0][j]; //初始化最大范围向量 } for(int i=1;i<_size;i++) //统计数据集中每一个数据向量 { for(int j=0;j<_dims;j++) { if( (*this)[i][j] < mn[j] ) //比较每一维,寻找最小值 mn[j] = (*this)[i][j]; if( (*this)[i][j] > mx[j] ) //比较每一维,寻找最大值 mx[j] = (*this)[i][j]; } } _HyperRectangle hr(mx, mn); return hr; //返回一个_HyperRectangle结构 }
std::pair<KDTree::_Examplar, double> KDTree::KDTree::findNearest_i( KDTreeNode *root, _Examplar target ) { KDTreeNode *pSearch = root; //堆栈用于保存搜索路径 std::vector<KDTreeNode*> search_path; _Examplar nearest; double max_dist; while(pSearch != NULL) //首先通过二叉查找得到搜索路径 { search_path.push_back(pSearch); int s = pSearch->splitDim(); if(target[s] <= pSearch->getDomElt()[s]) { pSearch = pSearch->_left_child; } else { pSearch = pSearch->_right_child; } } nearest = search_path.back()->getDomElt(); //取路径中最后的叶子节点为回溯前的最邻近点 max_dist = Distance_exm(nearest, target); search_path.pop_back(); //回溯搜索路径 while(!search_path.empty()) { KDTreeNode *pBack = search_path.back(); search_path.pop_back(); if( pBack->_left_child == NULL && pBack->_right_child == NULL) //如果是叶子节点,就直接比较距离的大小 { if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) ) { nearest = pBack->getDomElt(); max_dist = Distance_exm(pBack->getDomElt(), target); } } else { int s = pBack->splitDim(); if( abs(pBack->getDomElt()[s] - target[s]) < max_dist) //以target为圆心,max_dist为半径的圆和分割面如果 { //有交割,则需要进入另一边子空间搜索 if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) ) { nearest = pBack->getDomElt(); max_dist = Distance_exm(pBack->getDomElt(), target); } if(target[s] <= pBack->getDomElt()[s]) //如果target位于左子空间,就应进入右子空间 pSearch = pBack->_right_child; else pSearch = pBack->_left_child; //如果target位于右子空间,就应进入左子空间 if(pSearch != NULL) search_path.push_back(pSearch); //将新的节点加入search_path中 } } } std::pair<_Examplar, double> res(nearest, max_dist); return res; //返回包含最邻近点和最近距离的pair }
int KDTree::KDTree::findNearest_range( KDTreeNode *root, _Examplar target, double range, std::vector<std::pair<_Examplar, double>> &res_nearest ) { if(root == NULL) return 0; double dist_sq, dx; int ret, added_res = 0; dist_sq = 0; dist_sq = Distance_exm(root->getDomElt(), target); //计算搜索路径中每个节点和target的距离 if(dist_sq <= range) { //将范围内的近邻添加到结果向量res_nearest中 std::pair<_Examplar,double> temp(root->getDomElt(), dist_sq); res_nearest.push_back(temp); //结果个数+1 added_res = 1; } dx = target[root->splitDim()] - root->getDomElt()[root->splitDim()]; //左子树或右子树递归的查找 ret = findNearest_range(dx <= 0.0 ? root->_left_child : root->_right_child, target, range, res_nearest); //当另外一边可能存在范围内的近邻 if(ret >= 0 && fabs(dx) < range) { added_res += ret; ret = findNearest_range(dx <= 0.0 ? root->_right_child : root->_left_child, target, range, res_nearest); } added_res += ret; return added_res; //最终返回范围内的近邻个数 }
依然利用前述实例的数据来做测试,查找(2.1,3.1)和(2,4.5)两点的最近邻,并查找距离在4以内的所有近邻。程序运行结果如下:
图2 查找(2.1,3.1)的结果 图3 查找(2,4.5)的结果
附件:http://files.cnblogs.com/eyeszjwang/kdtree.rar
转载:http://www.cnblogs.com/eyeszjwang/articles/2432465.html