近邻搜索之制高点树(VP-Tree)

引子

近邻搜索是一种很基础又相当重要的操作,除了信息检索以外,还被广泛用于计算机视觉、机器学习等领域,如何快速有效地做近邻查询一直是一项热门的研究。较早提出的方法多基于空间划分(Space Partition),最具有代表性的如kd-tree(kdt)、球树等。本篇将介绍基于空间划分的方法中的一种,制高点树(Vantage Point Tree,vpt),它最初在1993年提出,比kdt稍晚,提供了一个不一样的建树思路。

VPT结构

和kdt一样,vpt也是一类二叉树,不同的是在每个节点的划分策略。略微回顾一下kdt,它在每个节点选择一个维度,根据数据点在该维度上的大小将数据均分为二。而在vpt中,首先从节点中选择一个数据点(可随机选)作为制高点(vp),然后算出其它点到vp的距离大小,最后根据该距离大小将数据点均分为二。建树算法如下:

  1. 选择某数据点v作为vp
  2. 计算其它点{Xi}到v的距离{Di}
  3. 求出{Di}的中值M,到v的距离小于M的数据点分给左子树,大于M的数据点分给右子树(距离恰好等于M的点可归入任意一侧)
  4. 递归地建立左子树和右子树
这里提供一个简单的例子如图,框中为平面上的点,其中红框为选中的vp,根据其它点到vp的距离进行了子树划分。

近邻搜索之制高点树(VP-Tree)_第1张图片

VPT查询算法


vpt查询是精确近邻查询,较适合范围查询,也可方便地扩展为k近邻查询。

进行近邻查询时,假定查询点为q,当前的制高点为v,距离中值为M,则可按如下策略搜索与q的距离小于r的点集:

(1)  若 dist(q,v)+r≥M,递归地搜索右子树(球外区域)

(2)  若 dist(q,v)-r≤M,递归地搜索左子树(球内区域)

近邻搜索之制高点树(VP-Tree)_第2张图片
为了方便写公式,用图片文字来进行证明,其实就是简单的三角形不等式的应用。

简易实现代码


最后上点干货,一个简易c++实现如下:
#ifndef _VPTREE_HEADER_
#define _VPTREE_HEADER_

#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <limits>
#include <queue>
#include <vector>
//#include "fnn.h"

template
class VpTree
{
public:
    VpTree() : _root(0) {}

    ~VpTree() {
        delete _root;
    }

    void create( const std::vector& items ) {
        delete _root;
        _items = items;
        _root = buildFromPoints(0, items.size());
    }

    void search( const T& target, int k, std::vector* results, 
        std::vector* distances) 
    {
        std::priority_queue heap;

        _tau = std::numeric_limits::max();
        search( _root, target, k, heap );

        results->clear(); distances->clear();

        while( !heap.empty() ) {
            results->push_back( _items[heap.top().index] );
            distances->push_back( heap.top().dist );
            heap.pop();
        }

        std::reverse( results->begin(), results->end() );
        std::reverse( distances->begin(), distances->end() );
		printf("vp search dist = %f\n",distances->at(0));
		brute(target);
    }

	void search(const T& target,std::vector* results,std::vector* distances){
        int idx;
		double min = 1.0e+10;
		for(int i=0;i<_items.size();i++){
			double dist = distance( _items[i], target );
			if(distpush_back(_items[idx]);
		distances->push_back(min);
	}

	int range_search(const T& target, double range, int *list, int &listnum){
		int hit = 0;
		for(int i=0;i<_items.size();i++){
			double dist = distance( _items[i], target );
			//debug here
			/*if(getId(_items[i])==4){
				printf("vp dist=%f range=%f\n",dist,range);
			}*/
			//-debug
			if(dist<=range){  //inside, need to check
				//list[listnum++] = getId(_items[i]);
				int id = getId(_items[i]);
				list[id] = 1;
				listnum++;
				hit++;
			}
		}
		/*_tau = range;
		rsearch( _root, target, hit, list, listnum);*/
		return hit;
	}

private:
    std::vector _items;
    double _tau;

    struct Node 
    {
        int index;
        double threshold;
        Node* left;
        Node* right;

        Node() :
            index(0), threshold(0.), left(0), right(0) {}

        ~Node() {
            delete left;
            delete right;
        }
    }* _root;

    struct HeapItem {
        HeapItem( int index, double dist) :
            index(index), dist(dist) {}
        int index;
        double dist;
        bool operator<( const HeapItem& o ) const {
            return dist < o.dist;   
        }
    };

    struct DistanceComparator
    {
        const T& item;
        DistanceComparator( const T& item ) : item(item) {}
        bool operator()(const T& a, const T& b) {
            return distance( item, a ) < distance( item, b );
        }
    };

    Node* buildFromPoints( int lower, int upper )
    {
        if ( upper == lower ) {
            return NULL;
        }

        Node* node = new Node();
        node->index = lower;

        if ( upper - lower > 1 ) {

            // choose an arbitrary point and move it to the start
            int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
            std::swap( _items[lower], _items[i] );

            int median = ( upper + lower ) / 2;

            // partitian around the median distance
            std::nth_element( 
                _items.begin() + lower + 1, 
                _items.begin() + median,
                _items.begin() + upper,
                DistanceComparator( _items[lower] ));

            // what was the median?
            node->threshold = distance( _items[lower], _items[median] );

            node->index = lower;
            node->left = buildFromPoints( lower + 1, median );
            node->right = buildFromPoints( median, upper );
        }

        return node;
    }

	double brute(const T& target){
		double min = 1.0e+10;
		for(int i=0;i<_items.size();i++){
			double dist = distance( _items[i], target );
			if(distindex], target );
		if ( dist < _tau ) {
			counter++;
			//list[ listnum++ ] = getId(_items[node->index]);
			list[getId(_items[node->index])] = 1;
		}

		if ( node->left == NULL && node->right == NULL ) {
            return;
        }

        if ( dist < node->threshold ) {
            if ( dist - _tau <= node->threshold ) {
                rsearch( node->left, target, counter, list, listnum);
            }

            if ( dist + _tau >= node->threshold ) {
                rsearch( node->right, target, counter, list, listnum );
            }

        } 
		else {
            if ( dist + _tau >= node->threshold ) {
                rsearch( node->right, target, counter, list, listnum );
            }

            if ( dist - _tau <= node->threshold ) {
                rsearch( node->left, target, counter, list, listnum);
            }
        }
	}

    void search( Node* node, const T& target, int k,
                 std::priority_queue& heap )
    {
        if ( node == NULL ) return;

        double dist = distance( _items[node->index], target );
        //printf("dist=%g tau=%gn", dist, _tau );

        if ( dist < _tau ) {
            if ( heap.size() == k ) heap.pop();
            heap.push( HeapItem(node->index, dist) );
            if ( heap.size() == k ) _tau = heap.top().dist;
        }

        if ( node->left == NULL && node->right == NULL ) {
            return;
        }

        if ( dist < node->threshold ) {
            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }

            if ( dist + _tau >= node->threshold ) {
                search( node->right, target, k, heap );
            }

        } else {
            if ( dist + _tau >= node->threshold ) {
                search( node->right, target, k, heap );
            }

            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }
        }
    }
};

#endif

你可能感兴趣的:(近邻搜索之制高点树(VP-Tree))