近邻搜索是一种很基础又相当重要的操作,除了信息检索以外,还被广泛用于计算机视觉、机器学习等领域,如何快速有效地做近邻查询一直是一个热门的研究课题。较早提出的方法多基于空间划分(Space Partition),最具代表性的如kd-tree(kdt)、球树等。本篇将介绍基于空间划分的方法中的一种——制高点树(Vantage Point Tree,vpt),它最初于1993年提出,比kdt稍晚,提供了一个不一样的建树思路。
和kdt一样,vpt也是一类二叉树,不同之处在于每个节点的划分策略。略微回顾一下kdt,它在每个节点选择一个维度,根据数据点在该维度上的大小将数据均分为二。而在vpt中,首先从节点中选择一个数据点(可随机选)作为制高点(vp),然后算出其它点到vp的距离,最后根据该距离的中值将数据点均分为二:距离不大于中值的点进入左子树(球内区域),其余的点进入右子树(球外区域)。建树算法如下:
进行近邻查询时,假定查询点为q,当前的制高点为v,距离中值为M,则可按如下策略搜索与q点距离小于r的点集:
(1) 若 dist(q,v)+r≥M,递归地搜索右子树(球外区域)
(2) 若 dist(q,v)-r≤M,递归地搜索左子树(球内区域)
#ifndef _VPTREE_HEADER_
#define _VPTREE_HEADER_
#include <stdlib.h>
#include <stdio.h>
#include <algorithm>
#include <limits>
#include <queue>
#include <vector>
//#include "fnn.h"
template
class VpTree
{
public:
VpTree() : _root(0) {}
~VpTree() {
delete _root;
}
void create( const std::vector& items ) {
delete _root;
_items = items;
_root = buildFromPoints(0, items.size());
}
void search( const T& target, int k, std::vector* results,
std::vector* distances)
{
std::priority_queue heap;
_tau = std::numeric_limits::max();
search( _root, target, k, heap );
results->clear(); distances->clear();
while( !heap.empty() ) {
results->push_back( _items[heap.top().index] );
distances->push_back( heap.top().dist );
heap.pop();
}
std::reverse( results->begin(), results->end() );
std::reverse( distances->begin(), distances->end() );
printf("vp search dist = %f\n",distances->at(0));
brute(target);
}
void search(const T& target,std::vector* results,std::vector* distances){
int idx;
double min = 1.0e+10;
for(int i=0;i<_items.size();i++){
double dist = distance( _items[i], target );
if(distpush_back(_items[idx]);
distances->push_back(min);
}
int range_search(const T& target, double range, int *list, int &listnum){
int hit = 0;
for(int i=0;i<_items.size();i++){
double dist = distance( _items[i], target );
//debug here
/*if(getId(_items[i])==4){
printf("vp dist=%f range=%f\n",dist,range);
}*/
//-debug
if(dist<=range){ //inside, need to check
//list[listnum++] = getId(_items[i]);
int id = getId(_items[i]);
list[id] = 1;
listnum++;
hit++;
}
}
/*_tau = range;
rsearch( _root, target, hit, list, listnum);*/
return hit;
}
private:
std::vector _items;
double _tau;
struct Node
{
int index;
double threshold;
Node* left;
Node* right;
Node() :
index(0), threshold(0.), left(0), right(0) {}
~Node() {
delete left;
delete right;
}
}* _root;
struct HeapItem {
HeapItem( int index, double dist) :
index(index), dist(dist) {}
int index;
double dist;
bool operator<( const HeapItem& o ) const {
return dist < o.dist;
}
};
struct DistanceComparator
{
const T& item;
DistanceComparator( const T& item ) : item(item) {}
bool operator()(const T& a, const T& b) {
return distance( item, a ) < distance( item, b );
}
};
Node* buildFromPoints( int lower, int upper )
{
if ( upper == lower ) {
return NULL;
}
Node* node = new Node();
node->index = lower;
if ( upper - lower > 1 ) {
// choose an arbitrary point and move it to the start
int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
std::swap( _items[lower], _items[i] );
int median = ( upper + lower ) / 2;
// partitian around the median distance
std::nth_element(
_items.begin() + lower + 1,
_items.begin() + median,
_items.begin() + upper,
DistanceComparator( _items[lower] ));
// what was the median?
node->threshold = distance( _items[lower], _items[median] );
node->index = lower;
node->left = buildFromPoints( lower + 1, median );
node->right = buildFromPoints( median, upper );
}
return node;
}
double brute(const T& target){
double min = 1.0e+10;
for(int i=0;i<_items.size();i++){
double dist = distance( _items[i], target );
if(distindex], target );
if ( dist < _tau ) {
counter++;
//list[ listnum++ ] = getId(_items[node->index]);
list[getId(_items[node->index])] = 1;
}
if ( node->left == NULL && node->right == NULL ) {
return;
}
if ( dist < node->threshold ) {
if ( dist - _tau <= node->threshold ) {
rsearch( node->left, target, counter, list, listnum);
}
if ( dist + _tau >= node->threshold ) {
rsearch( node->right, target, counter, list, listnum );
}
}
else {
if ( dist + _tau >= node->threshold ) {
rsearch( node->right, target, counter, list, listnum );
}
if ( dist - _tau <= node->threshold ) {
rsearch( node->left, target, counter, list, listnum);
}
}
}
void search( Node* node, const T& target, int k,
std::priority_queue& heap )
{
if ( node == NULL ) return;
double dist = distance( _items[node->index], target );
//printf("dist=%g tau=%gn", dist, _tau );
if ( dist < _tau ) {
if ( heap.size() == k ) heap.pop();
heap.push( HeapItem(node->index, dist) );
if ( heap.size() == k ) _tau = heap.top().dist;
}
if ( node->left == NULL && node->right == NULL ) {
return;
}
if ( dist < node->threshold ) {
if ( dist - _tau <= node->threshold ) {
search( node->left, target, k, heap );
}
if ( dist + _tau >= node->threshold ) {
search( node->right, target, k, heap );
}
} else {
if ( dist + _tau >= node->threshold ) {
search( node->right, target, k, heap );
}
if ( dist - _tau <= node->threshold ) {
search( node->left, target, k, heap );
}
}
}
};
#endif