最近在看李航博士写的《统计学习方法》一书,其中第三章讲述的k近邻法里面有一节是k近邻法的实现:kd树,关于最近邻法的原理,请大家查阅相关资料,作为一个初学者,觉得李航博士的书已经写的很好了,推荐大家阅读,本文的目的只是记录下代码实现。
一、所用到的数据结构描述
1.树节点,看代码:
template //使用了C++的模板
struct kd_node //kd树的节点
{
T data; //节点中的数据
kd_node* l_child; //左孩子
kd_node* r_child; //右孩子
kd_node* p_father;//父节点
int depth; //表示节点的深度,默认值是0
kd_node(T d,kd_node* l = NULL,kd_node* r = NULL,kd_node* f=NULL,int dep=0)//构造函数
{
data = d;
l_child = l;
r_child = r;
p_father = f;
depth = dep;
}
kd_node(int d)//构造函数
{
depth = d;
l_child = NULL;
r_child = NULL;
p_father = NULL;
}
kd_node() //构造函数
{
l_child = NULL;
r_child = NULL;
p_father = NULL;
}
};
2.kd树,看代码:
template
class kdTree
{
private:
kd_node >* nearest;//用来保存离目标节点最近的节点
public:
kd_node >* root; //根节点指针
void createKdTree(vector > input)//根据输入数据创建kd树
{
root = new kd_node >();
root->depth = 0; //根节点的深度为0
this->split(input,this->root); //调用划分函数,这是一个递归的函数
}
const T getMedian(const vector& a) //获得一个向量的中位数
{
size_t n = a.size();
vector tmp(a.begin(),a.end());
sort(tmp.begin(),tmp.end()); //对a进行排序
return tmp[n/2]; //返回中位数
}
//下面的split函数是构造kd树的关键函数,它接收两个参数,第一个参数src是用来建立kd树的训练数据,第二个参数是指向一棵树的根节点的指针
void split(const vector > src,kd_node >* &p) //根据第split_index的值,将src划分为两个子vector
{
if(src.empty())//如果src为空,不用继续划分
{
p=NULL; //p置为NULL,保证叶子节点的子节点为NULL
return;
}
int split_index = p->depth%DIMENSION; //用来划分的向量的维数的下标,DIMENSION是一个宏,最后贴完整代码时候可见,它表示数据的维数
vector median_element = getMedianElement(src,split_index);//获得中位数据点,此函数有点问题,会在后面详述
p->data = median_element; //将根节点的数据置为src划分维数中位数所在的数据点
T median = median_element[split_index];
vector >l; //用来保存划分后的左孩子
vector >r; //用来保存划分后的右孩子
size_t n = src.size();
for(size_t i = 0;i < n;i++)
{
vector v = src[i];
if(v[split_index] < median) //小于中位数的放在左孩子里面
{
l.push_back(v);
}
if(v[split_index] > median)//大于中位数的放在右孩子里面
{
r.push_back(v);
}
}
if(!l.empty()) //如果左孩子不为空,则递归的划分左孩子
{
p->l_child = new kd_node >(p->depth+1); //孩子节点的深度比父节点大一
p->l_child->p_father = p;
split(l,p->l_child);
}
if (!r.empty()) //如果右孩子不为空,则递归的划分右孩子
{
p->r_child = new kd_node >(p->depth+1); //孩子节点的深度比父节点大一
p->r_child->p_father = p;
split(r,p->r_child);
}
}
vector getMedianElement(vector > input,int split_index)//返回以split_index维划分的那个元素的整体
{
size_t n = input.size();
vector tmp;
tmp.resize(n);
for(size_t i = 0;i < n;i++)
{
tmp[i] = input[i][split_index];
}
T median = this->getMedian(tmp);
for(size_t i = 0;i < n;i++)
{
if(input[i][split_index]==median)
{
return input[i];
}
}
vector result;
return result;
}
void printTree(kd_node >* p)//递归的打印以p为根节点的树
{
if(p==NULL) //递归结束条件
{
return;
}
vector data = p->data;
size_t n = data.size();
cout << "(";
for(size_t i = 0;i < n;i++)
{
cout << data[i];
if(i!=n-1)
{
cout << ",";
}
}
cout << ")";
printTree(p->l_child);//打印左子树
printTree(p->r_child);//打印右子树
}
vector search(vector& target,int k=1,kd_node >*root=this->root)//搜寻树中target的k邻近点,默认是搜索最邻近点
{
//step1:找到包含目标节点的叶节点
kd_node >*tmp_nearest = root;
while(tmp_nearest!=NULL) //nearest==NULL说明到达了叶节点
{
int currentDimensition = tmp_nearest->depth%DIMENSION;
if (target[currentDimensition] < tmp_nearest->data[currentDimensition])//小于则移动到左子树
{
tmp_nearest = tmp_nearest->l_child;
if(tmp_nearest)
{
nearest = tmp_nearest;
}
}else //否则移动到右子树
{
tmp_nearest = tmp_nearest->r_child;
if(tmp_nearest)
{
nearest = tmp_nearest;
}
}
}
cout << "the nearest leaf node is:" << endl;
this->printNode(nearest->data);
_search(nearest,nearest->p_father,target,k); //递归的回退到根节点
return nearest->data;
}
void _search(kd_node >*p,kd_node >*pFather,vector& target,int k) //递归的找寻节点p是不是最近点
{
if(isCloser(pFather->data,nearest->data,target))//当前点离目标点更近
{
nearest = pFather;
}
kd_node >*other_child = p==pFather->l_child?pFather->r_child:pFather->l_child;//获得另一子节点的指针
if(other_child&&isCloser(other_child->data,nearest->data,target))//另一子节点离目标点更近,前提是另一子节点不为空
{
nearest = other_child;
search(target,1,other_child); //以另一子节点作为搜索的起点
}
else //另一子节点离目标点远
{
if(pFather==this->root) //到达根节点停止
{
return;
}
else
{
_search(pFather,pFather->p_father,target,1);
}
}
}
bool isCloser(vector&first,vector&second,vector&target)//如果first离target比second近则返回true,否则返回false
{
return (distance(first,target) < distance(second,target));
}
double distance(vector&first,vector&second) //计算两个向量的欧式距离
{
size_t n = first.size();
assert(first.size()==second.size());
double result = 0;
for(size_t i = 0;i < n;i++)
{
result += (first[i]-second[i])*(first[i]-second[i]);
}
result = sqrt(result);
return result;
}
void printNode(vector& node) //打印一个节点
{
cout << "(";
size_t n = node.size();
for(size_t i = 0;i < n;i++)
{
cout << node[i];
if(i != n-1)
{
cout << ",";
}
}
cout << ")" << endl;
}
};
二、相关.函数解释:
其中在构造kd树时候最重要的函数是split,这个函数接收两个参数,一个是用来创建kd树的训练数据,另一个是根节点,因为这个划分是一个递归的过程它的流程如下:
step1:判断src是否为空,如果为空说明当前节点中没有数据,将指针置为NULL并停止递归
step2:获得src中的中位数,以此数据将src划分为两部分,分别放在vector l 和vector r中
step3:如果l、r不为空,则以l、r中的数据作为新的src、以当前节点对应于l和r的孩子节点作为新节点(孩子节点的深度要加一),递归的调用自己,从而完成对整个数据的划分。
在数据的搜索中,最重要的两个函数是search函数和_search函数。下面简单地说一下search算法的流程:
step1:在kd树中找到包含目标节点的叶子节点。此过程类似于二分查找,从根节点出发,一直向叶子节点方向移动,如果目标点当前维比切分点的值小则向左子树移动,否则向右子树移动,知道到达叶子节点。
step2:以此叶子节点作为“当前最近点”
step3:递归的向上回退,在每一个回退节点进行以下操作:
3.1如果该节点比当前最近点离目标点更近,则将当前最近点置换为该节点。
3.2检查此节点的兄弟节点是否比当前最近点更近,如果是从兄弟节点处在递归的执行搜索,否则继续回退
step4:递归到根节点时,搜索结束。最终的“当前最近点”即为目标点的最邻近点。
其他一些函数是工具函数不一一解释
三、最后的几点说明
1、虽说是k邻近法,但实际本文只是写了在k=1时,即最邻近法的代码,暂时没改到k近邻法,相信也不难,个人感觉只需在回退的时候不是简单地进行当前最近点的替换,在加入一些保存的步骤即可,然后取出保存结果的前k个即为k近邻。(此法本人没有验证)
2、文中提到的getMedianElement有问题,它返回的不一定是真正的中位数,只是满足等于划分维数数据的第一个点,如(1,2)(1,3)(1,7)这三个点以第一维作为划分的话,1,1,1的中位数是1,我取了第一个1,实际应该用第二个1里面的数据,我的函数将会返回(1,2)但返回(1,3)才是文中算法的目的。如果想要改进,可以考虑使用set,但其实本人觉得没有改的必要,权当你对训练数据进行了从新排序吧。
3、本人能力有限,如有错误,欢迎大家指出。