k近邻法的实现:kd树

最近在看李航博士写的《统计学习方法》一书,其中第三章讲述的k近邻法里面有一节是k近邻法的实现:kd树,关于最近邻法的原理,请大家查阅相关资料,作为一个初学者,觉得李航博士的书已经写的很好了,推荐大家阅读,本文的目的只是记录下代码实现。

一、所用到的数据结构描述

1.树节点,看代码:

template   //使用了C++的模板
struct kd_node         //kd树的节点
{
	T data;           //节点中的数据
	kd_node* l_child; //左孩子
	kd_node* r_child; //右孩子
	kd_node* p_father;//父节点
	int depth;        //表示节点的深度,默认值是0
	kd_node(T d,kd_node* l = NULL,kd_node* r = NULL,kd_node* f=NULL,int dep=0)//构造函数
	{
		data = d;
		l_child = l;
		r_child = r;
		p_father = f;
		depth = dep;
	}
	kd_node(int d)//构造函数
	{
		depth = d;
		l_child = NULL;
		r_child = NULL;
		p_father = NULL;
	} 
	kd_node()    //构造函数
	{
		l_child = NULL;
		r_child = NULL;
		p_father = NULL;
	}
};

2.kd树,看代码:

template 
class kdTree
{
private:
	kd_node >* nearest;//用来保存离目标节点最近的节点
public:

	kd_node >* root;      //根节点指针
	void createKdTree(vector > input)//根据输入数据创建kd树
	{
		root = new kd_node >();
		root->depth = 0;                     //根节点的深度为0
		this->split(input,this->root);       //调用划分函数,这是一个递归的函数
	}

	const T getMedian(const vector& a)             //获得一个向量的中位数
	{
		size_t n = a.size();
		vector tmp(a.begin(),a.end());
		sort(tmp.begin(),tmp.end());           //对a进行排序
		return tmp[n/2];                      //返回中位数
	}
	//下面的split函数是构造kd树的关键函数,它接收两个参数,第一个参数src是用来建立kd树的训练数据,第二个参数是指向一棵树的根节点的指针
	void split(const vector > src,kd_node >* &p)  //根据第split_index的值,将src划分为两个子vector
	{
		if(src.empty())//如果src为空,不用继续划分
		{
			p=NULL;   //p置为NULL,保证叶子节点的子节点为NULL
			return;
		}
		int split_index = p->depth%DIMENSION;            //用来划分的向量的维数的下标,DIMENSION是一个宏,最后贴完整代码时候可见,它表示数据的维数 
		vector median_element = getMedianElement(src,split_index);//获得中位数据点,此函数有点问题,会在后面详述
	        p->data = median_element;                                    //将根节点的数据置为src划分维数中位数所在的数据点
		T median = median_element[split_index];			     
		vector >l;                                         //用来保存划分后的左孩子
		vector >r;					    //用来保存划分后的右孩子  
		size_t n = src.size();
		for(size_t i = 0;i < n;i++)
		{
			vector v = src[i];
			if(v[split_index] < median) //小于中位数的放在左孩子里面
			{
				l.push_back(v);
			}
			if(v[split_index] > median)//大于中位数的放在右孩子里面
			{
				r.push_back(v);
			}
		}
		if(!l.empty())                 //如果左孩子不为空,则递归的划分左孩子
		{
		p->l_child = new kd_node >(p->depth+1);      //孩子节点的深度比父节点大一
		p->l_child->p_father = p;
		split(l,p->l_child);
		}
		if (!r.empty())               //如果右孩子不为空,则递归的划分右孩子
		{
			p->r_child = new kd_node >(p->depth+1);  //孩子节点的深度比父节点大一
			p->r_child->p_father = p;
			split(r,p->r_child);
		}
	}

	vector getMedianElement(vector > input,int split_index)//返回以split_index维划分的那个元素的整体
	{
		size_t n = input.size();
		vector tmp;
		tmp.resize(n);
		for(size_t i = 0;i < n;i++)
		{
			tmp[i] = input[i][split_index];
		}
		T median = this->getMedian(tmp);
		for(size_t i = 0;i < n;i++)
		{
			if(input[i][split_index]==median)
			{
				return input[i];
			}
		}
		vector result;
		return result;
	}

	void printTree(kd_node >* p)//递归的打印以p为根节点的树
	{
		if(p==NULL)         //递归结束条件
		{
			return;
		}
		
		vector data = p->data;
		size_t n = data.size();
		cout << "(";
		for(size_t i = 0;i < n;i++)
		{
			cout << data[i];
			if(i!=n-1)
			{
				cout << ",";
			}
		}
		cout << ")";
	
		printTree(p->l_child);//打印左子树
		printTree(p->r_child);//打印右子树
	}

	vector search(vector& target,int k=1,kd_node >*root=this->root)//搜寻树中target的k邻近点,默认是搜索最邻近点
	{
		//step1:找到包含目标节点的叶节点
		kd_node >*tmp_nearest = root;
		while(tmp_nearest!=NULL) //nearest==NULL说明到达了叶节点
		{
			int currentDimensition = tmp_nearest->depth%DIMENSION;
			if (target[currentDimensition] < tmp_nearest->data[currentDimensition])//小于则移动到左子树
			{
				tmp_nearest = tmp_nearest->l_child;
				if(tmp_nearest)
				{
					nearest = tmp_nearest;
				}
			}else                                                              //否则移动到右子树
			{
				tmp_nearest = tmp_nearest->r_child;
				if(tmp_nearest)
				{
					nearest = tmp_nearest;
				}
			}
		}
		cout << "the nearest leaf node is:" << endl;
		this->printNode(nearest->data);
	    _search(nearest,nearest->p_father,target,k);                                            //递归的回退到根节点
	     return nearest->data;
	}

	void _search(kd_node >*p,kd_node >*pFather,vector& target,int k)                         //递归的找寻节点p是不是最近点
	{
		if(isCloser(pFather->data,nearest->data,target))//当前点离目标点更近
		{
			nearest = pFather;
		}
		
		kd_node >*other_child = p==pFather->l_child?pFather->r_child:pFather->l_child;//获得另一子节点的指针
		if(other_child&&isCloser(other_child->data,nearest->data,target))//另一子节点离目标点更近,前提是另一子节点不为空
		{
			nearest = other_child;
			search(target,1,other_child);                  //以另一子节点作为搜索的起点
		}
		else                                              //另一子节点离目标点远
		{
			if(pFather==this->root)                       //到达根节点停止
			{
				return;
			}
			else
			{
		       _search(pFather,pFather->p_father,target,1);
			}
		}
	}

	bool isCloser(vector&first,vector&second,vector&target)//如果first离target比second近则返回true,否则返回false
	{
		return (distance(first,target) < distance(second,target));
	}

	double distance(vector&first,vector&second)               //计算两个向量的欧式距离
	{
		size_t n = first.size();
		assert(first.size()==second.size());
		double result = 0;
		for(size_t i = 0;i < n;i++)
		{
			result += (first[i]-second[i])*(first[i]-second[i]);
		}
		result = sqrt(result);
		return result;
	}

	void printNode(vector& node)  //打印一个节点
	{
		cout << "(";
		size_t n = node.size();
		for(size_t i = 0;i < n;i++)
		{
			cout << node[i];
			if(i != n-1)
			{
				cout << ",";
			}
		}
		cout << ")" << endl;
	}
};
二、相关.函数解释:

其中在构造kd树时候最重要的函数是split,这个函数接收两个参数,一个是用来创建kd树的训练数据,另一个是根节点,因为这个划分是一个递归的过程它的流程如下:

step1:判断src是否为空,如果为空说明当前节点中没有数据,将指针置为NULL并停止递归

step2:获得src中的中位数,以此数据将src划分为两部分,分别放在vector  l 和vector r中

step3:如果l、r不为空,则以l、r中的数据作为新的src、以当前节点对应于l和r的孩子节点作为新节点(孩子节点的深度要加一),递归的调用自己,从而完成对整个数据的划分。

在数据的搜索中,最重要的两个函数是search函数和_search函数。下面简单地说一下search算法的流程:

step1:在kd树中找到包含目标节点的叶子节点。此过程类似于二分查找,从根节点出发,一直向叶子节点方向移动,如果目标点当前维比切分点的值小则向左子树移动,否则向右子树移动,知道到达叶子节点。

step2:以此叶子节点作为“当前最近点”

step3:递归的向上回退,在每一个回退节点进行以下操作:

3.1如果该节点比当前最近点离目标点更近,则将当前最近点置换为该节点。

3.2检查此节点的兄弟节点是否比当前最近点更近,如果是从兄弟节点处在递归的执行搜索,否则继续回退

step4:递归到根节点时,搜索结束。最终的“当前最近点”即为目标点的最邻近点。

其他一些函数是工具函数不一一解释

三、最后的几点说明

1、虽说是k邻近法,但实际本文只是写了在k=1时,即最邻近法的代码,暂时没改到k近邻法,相信也不难,个人感觉只需在回退的时候不是简单地进行当前最近点的替换,在加入一些保存的步骤即可,然后取出保存结果的前k个即为k近邻。(此法本人没有验证)

2、文中提到的getMedianElement有问题,它返回的不一定是真正的中位数,只是满足等于划分维数数据的第一个点,如(1,2)(1,3)(1,7)这三个点以第一维作为划分的话,1,1,1的中位数是1,我取了第一个1,实际应该用第二个1里面的数据,我的函数将会返回(1,2)但返回(1,3)才是文中算法的目的。如果想要改进,可以考虑使用set,但其实本人觉得没有改的必要,权当你对训练数据进行了从新排序吧。

       3、本人能力有限,如有错误,欢迎大家指出。

你可能感兴趣的:(统计学习方法)