The main use of a K-D tree is the k-nearest-neighbor (k-NN) algorithm in machine learning. A plain K-D tree search only returns the single nearest neighbor, but combined with a priority queue it extends naturally to k nearest neighbors. What follows is a simple K-D tree implementation; it has only been through light testing, and no major bugs have shown up so far.
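As a minimal sketch of the priority-queue extension mentioned above (not part of the header below; Candidate, KnnQueue, and Offer are illustrative names): keep a max-heap of the best k candidates seen so far, offer every visited node to it, and prune a subtree only when the heap is full and the splitting plane is farther away than the worst candidate on top of the heap.

#include <queue>
#include <utility>

typedef std::pair<float, int> Candidate;          // (distance, feature id)
typedef std::priority_queue<Candidate> KnnQueue;  // max-heap: worst kept candidate on top

// Offer one visited node; the queue retains only the k closest seen so far.
void Offer(KnnQueue & q, unsigned int k, float distance, int id) {
  if (q.size() < k) {
    q.push(Candidate(distance, id));
  } else if (distance < q.top().first) {
    q.pop();
    q.push(Candidate(distance, id));
  }
}

The K-D tree header itself follows.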
#ifndef KDTREE_H
#define KDTREE_H
#include <vector>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cassert>
#include <cstddef>
using ::std::vector;
using ::std::cout;
using ::std::endl;
namespace sx {
typedef float DataType;
typedef unsigned int UInt;
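// A Feature is one K-dimensional data point plus an integer id that
// identifies it in the original data set.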
struct Feature {
vector<DataType> data;
int id;
Feature() {}
Feature(const vector<DataType> & d, int i)
: data(d), id(i) {}
};
template <UInt K>
class KDTree {
public:
KDTree();
virtual ~KDTree();
KDTree(const KDTree & rhs);
const KDTree & operator = (const KDTree & rhs);
void Clean();
void Build(const vector<Feature> & matrix_feature);
int FindNearestFeature(const Feature & target) const;
int FindNearestFeature(const Feature & target,
DataType & min_difference) const;
void Show() const;
private:
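// One tree node: a Feature plus the depth at which it sits; depth % K is
// the coordinate this node splits on.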
struct KDNode {
KDNode * left;
KDNode * right;
Feature feature;
int depth;
KDNode(const Feature & f, KDNode * lt, KDNode * rt, int d)
: left(lt), right(rt), feature(f), depth(d) {}
};
KDNode * root_;
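// Orders Features by a single coordinate, used when sorting along the
// current splitting axis.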
struct Comparator {
int index_comparator;
Comparator(int ix)
: index_comparator(ix) {}
bool operator () (const Feature & lhs, const Feature & rhs) const {
return lhs.data[index_comparator] < rhs.data[index_comparator];
}
};
KDNode * Clone(KDNode * t) const;
void Clean(KDNode * & t);
void SortFeature(vector<Feature> & features, int index);
void Build(const vector<Feature> & matrix_feature,
KDNode * & t, int depth);
DataType Feature2FeatureDifference(const Feature & f1,
const Feature & f2) const;
int FindNearestFeature(const Feature & target,
DataType & min_difference, KDNode * t) const;
void Show(KDNode * t) const;
};
template <UInt K>
KDTree<K>::
KDTree()
: root_(NULL) {}
template <UInt K>
KDTree<K>::
~KDTree() {
Clean();
}
template <UInt K>
KDTree<K>::
KDTree(const KDTree & rhs) {
*this = rhs;
}
template <UInt K>
const KDTree<K> & KDTree<K>::
operator = (const KDTree & rhs) {
if (this != &rhs) {
Clean();
root_ = Clone(rhs.root_);
}
return *this;
}
template <UInt K>
void KDTree<K>::
Clean() {
Clean(root_);
}
template <UInt K>
void KDTree<K>::
Build(const vector<Feature> & matrix_feature) {
if (matrix_feature.size() != 0) {
assert(matrix_feature[0].data.size() == K);
}
Build(matrix_feature, root_, 0);
}
template <UInt K>
int KDTree<K>::
FindNearestFeature(const Feature & target) const {
DataType min_difference;
return FindNearestFeature(target, min_difference);
}
template <UInt K>
int KDTree<K>::
FindNearestFeature(const Feature & target, DataType & min_difference) const {
min_difference = 10e8;
return FindNearestFeature(target, min_difference, root_);
}
template <UInt K>
void KDTree<K>::
Show() const {
Show(root_);
return ;
}
template <UInt K>
typename KDTree<K>::KDNode * KDTree<K>::
Clone(KDNode * t) const {
if (NULL == t) {
return NULL;
}
// Deep-copy the children; copying the raw pointers would make two trees
// share nodes and delete them twice.
return new KDNode(t->feature, Clone(t->left), Clone(t->right), t->depth);
}
template <UInt K>
void KDTree<K>::
Clean(KDNode * & t) {
if (t != NULL) {
Clean(t->left);
Clean(t->right);
delete t;
}
t = NULL;
}
template <UInt K>
void KDTree<K>::
SortFeature(vector<Feature> & features, int index) {
std::sort(features.begin(), features.end(), Comparator(index));
}
template <UInt K>
void KDTree<K>::
Build(const vector<Feature> & matrix_feature, KDNode * & t, int depth) {
if (matrix_feature.size() == 0) {
t = NULL;
return ;
}
vector<Feature> temp_feature = matrix_feature;
vector<Feature> left_feature;
vector<Feature> right_feature;
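// Split on axis (depth % K): sort by that coordinate, take the median
// element as this node, and recurse on the two halves.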
SortFeature(temp_feature, depth % K);
int length = (int)temp_feature.size();
int middle_position = length / 2;
t = new KDNode(temp_feature[middle_position], NULL, NULL, depth);
for (int i = 0; i < middle_position; ++i) {
left_feature.push_back(temp_feature[i]);
}
for (int i = middle_position + 1; i < length; ++i) {
right_feature.push_back(temp_feature[i]);
}
Build(left_feature, t->left, depth + 1);
Build(right_feature, t->right, depth + 1);
return ;
}
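// Euclidean (L2) distance between two Features.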
template <UInt K>
DataType KDTree<K>::
Feature2FeatureDifference(const Feature & f1, const Feature & f2) const {
DataType diff = 0.0;
assert(f1.data.size() == f2.data.size());
for (int i = 0; i < (int)f1.data.size(); ++i) {
diff += (f1.data[i] - f2.data[i]) * (f1.data[i] - f2.data[i]);
}
return std::sqrt(diff);
}
template <UInt K>
int KDTree<K>::
FindNearestFeature(const Feature & target, DataType & min_difference,
KDNode * t) const {
if (NULL == t) {
return -1;
}
DataType diff_parent = Feature2FeatureDifference(target, t->feature);
DataType diff_left = 10e8;
DataType diff_right = 10e8;
int result_parent = -1;
int result_left = -1;
int result_right = -1;
if (diff_parent < min_difference) {
min_difference = diff_parent;
result_parent = t->feature.id;
}
if (NULL == t->left && NULL == t->right) {
return result_parent;
}
if (NULL == t->left /* && t->right != NULL */) {
result_right = FindNearestFeature(target, diff_right, t->right);
if (diff_right < min_difference) {
min_difference = diff_right;
result_parent = result_right;
}
return result_parent;
}
if (NULL == t->right /* && t->left != NULL */) {
result_left = FindNearestFeature(target, diff_left, t->left);
if (diff_left < min_difference) {
min_difference = diff_left;
result_parent = result_left;
}
return result_parent;
}
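// Both children exist: search the side of the splitting plane that
// contains the target first, then visit the far side only if the plane is
// closer than the best distance found so far.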
int index_feature = t->depth % K;
DataType diff_boundary =
std::fabs(target.data[index_feature] - t->feature.data[index_feature]);
if (target.data[index_feature] < t->feature.data[index_feature]) {
result_left = FindNearestFeature(target, diff_left, t->left);
if (diff_left < min_difference) {
min_difference = diff_left;
result_parent = result_left;
}
if (diff_boundary < min_difference) {
result_right = FindNearestFeature(target, diff_right, t->right);
if (diff_right < min_difference) {
min_difference = diff_right;
result_parent = result_right;
}
}
} else {
result_right = FindNearestFeature(target, diff_right, t->right);
if (diff_right < min_difference) {
min_difference = diff_right;
result_parent = result_right;
}
if (diff_boundary < min_difference) {
result_left = FindNearestFeature(target, diff_left, t->left);
if (diff_left < min_difference) {
min_difference = diff_left;
result_parent = result_left;
}
}
}
return result_parent;
}
template <UInt K>
void KDTree<K>::Show(KDNode * t) const {
// Guard against an empty tree.
if (NULL == t) {
return ;
}
cout << "ID: " << t->feature.id << endl;
cout << "Data: ";
for (int i = 0; i < (int)t->feature.data.size(); ++i) {
cout << t->feature.data[i] << " ";
}
cout << endl;
if (t->left != NULL) {
cout << "Left: " << t->feature.id << " -> " << t->left->feature.id << endl;
Show(t->left);
}
if (t->right != NULL) {
cout << "Right: " << t->feature.id << " -> " << t->right->feature.id << endl;
Show(t->right);
}
return ;
}
} /* sx */
#endif /* end of include guard: KDTREE_H */
Below is main.cc, which was used for testing.
// =============================================================================
//
// Filename: main.cc
//
// Description: K-D Tree
//
// Version: 1.0
// Created: 04/11/2013 04:28:02 PM
// Revision: none
// Compiler: g++
//
// Author: Geek SphinX(Perstare et Praestare), [email protected]
// Organization: Hefei University of Technology
//
// =============================================================================
#include "KDTree.h"
#include <iostream>
using namespace sx;
using namespace std;
int main(int argc, char *argv[]) {
KDTree<2> kdtree;
vector<Feature> vf;
int idx = 0;
for (int i = 0; i < 5; ++i) {
for (int j = 0; j < 5; ++j) {
vector<DataType> vd;
vd.push_back(i);
vd.push_back(j);
vf.push_back(Feature(vd, idx++));
}
}
kdtree.Build(vf);
kdtree.Show();
Feature target;
int n;
DataType x, y;
cin >> n;
while (n--) {
cin >> x >> y;
vector<DataType> vd;
vd.push_back(x);
vd.push_back(y);
target = Feature(vd, 0);
DataType md;
int t = kdtree.FindNearestFeature(target, md);
cout << "Result is " << t << endl;
for (int i = 0; i < (int)vf[t].data.size(); ++i) {
cout << vf[t].data[i] << " ";
}
cout << endl;
cout << "Minimum Difference is " << md << endl;
}
return 0;
}
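For reference, with the 5x5 grid built above (point (i, j) gets id i * 5 + j), entering 1 and then "2.4 2.4" should report id 12, i.e. the point (2, 2), with a minimum difference of sqrt(0.32), roughly 0.566. These expected values are worked out by hand from the distances, not captured from an actual run.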