K-D树 C++实现

       K-D树主要用于实现机器学习中的K近邻（KNN）算法。单纯的K-D树只能实现最近邻查询，但结合优先队列就可以扩展为K近邻查询。这里只是简单地实现了一下K-D树，经过简单测试，暂时没有发现重大bug。

#ifndef KDTREE_H
#define KDTREE_H


#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>
using ::std::vector;
using ::std::cout;
using ::std::endl;

namespace sx {

  typedef float DataType;
  typedef unsigned int UInt;

  // A point stored in the tree: `data` holds its coordinates and `id` is the
  // caller-chosen label reported back by KDTree::FindNearestFeature.
  struct Feature {
    std::vector<DataType> data;
    int id;

    // id defaults to -1 (the value FindNearestFeature uses for "not found")
    // instead of being left uninitialized.
    Feature() : id(-1) {}
    Feature(const std::vector<DataType> & d, int i)
      : data(d), id(i) {}
  };

  // K-d tree over Feature points with K coordinates each, answering
  // nearest-neighbour queries under Euclidean distance.
  //
  // The tree owns its nodes: copies are deep, and Clean() / the destructor
  // free every node.
  template <UInt K>
    class KDTree {
    public:
      KDTree();
      virtual ~KDTree();
      KDTree(const KDTree & rhs);

      const KDTree & operator = (const KDTree & rhs);

      // Frees every node; the tree becomes empty.
      void Clean();
      // Replaces the current contents with a tree built over matrix_feature
      // (median split at every level). Every feature must have exactly K
      // coordinates (checked with assert).
      void Build(const std::vector<Feature> & matrix_feature);

      // Returns the id of the feature closest to target, or -1 if the tree
      // is empty.
      int FindNearestFeature(const Feature & target) const;
      // As above, and stores the Euclidean distance to that feature in
      // min_difference (+infinity when the tree is empty).
      int FindNearestFeature(const Feature & target,
                             DataType & min_difference) const;

      // Prints every node (pre-order) and its child links to stdout.
      void Show() const;

    private:
      struct KDNode {
        KDNode * left;
        KDNode * right;
        Feature feature;
        int depth;  // level in the tree; the split axis is depth % K

        // Initializers are listed in declaration order (avoids -Wreorder).
        KDNode(const Feature & f, KDNode * lt, KDNode * rt, int d)
          : left(lt), right(rt), feature(f), depth(d) {}
      };

      KDNode * root_;

      // Orders features by a single coordinate.
      struct Comparator {
        int index_comparator;

        explicit Comparator(int ix)
          : index_comparator(ix) {}

        bool operator () (const Feature & lhs, const Feature & rhs) const {
          return lhs.data[index_comparator] < rhs.data[index_comparator];
        }
      };

      KDNode * Clone(KDNode * t) const;

      void Clean(KDNode * & t);
      void SortFeature(std::vector<Feature> & features, int index);
      void Build(const std::vector<Feature> & matrix_feature,
                 KDNode * & t, int depth);

      DataType Feature2FeatureDifference(const Feature & f1,
                                         const Feature & f2) const;
      int FindNearestFeature(const Feature & target,
                             DataType & min_difference, KDNode * t) const;

      void Show(KDNode * t) const;
    };

  template <UInt K>
    KDTree<K>::
    KDTree()
    : root_(NULL) {}

  template <UInt K>
    KDTree<K>::
    ~KDTree() {
      Clean();
    }

  // Deep copy. root_ is initialized directly: assigning through operator=
  // (as before) made Clean() free an uninitialized root_ pointer.
  template <UInt K>
    KDTree<K>::
    KDTree(const KDTree & rhs)
    : root_(Clone(rhs.root_)) {}

  template <UInt K>
  const KDTree<K> & KDTree<K>::
    operator = (const KDTree & rhs) {
      if (this != &rhs) {
        // Copy first, so a failed allocation leaves *this unchanged.
        KDNode * copy = Clone(rhs.root_);
        Clean();
        root_ = copy;
      }
      return *this;
    }

  template <UInt K>
    void KDTree<K>::
    Clean() {
      Clean(root_);
    }

  template <UInt K>
    void KDTree<K>::
    Build(const std::vector<Feature> & matrix_feature) {
      for (std::size_t i = 0; i < matrix_feature.size(); ++i) {
        assert(matrix_feature[i].data.size() == K);
      }
      // Release any previous tree; overwriting root_ directly leaked it.
      Clean();
      Build(matrix_feature, root_, 0);
    }

  template <UInt K>
    int KDTree<K>::
    FindNearestFeature(const Feature & target) const {
      DataType min_difference;
      return FindNearestFeature(target, min_difference);
    }

  template <UInt K>
    int KDTree<K>::
    FindNearestFeature(const Feature & target, DataType & min_difference) const {
      // Start from +infinity. The old 10e8 sentinel made any feature farther
      // than 1e9 unreachable, so -1 was returned for a non-empty tree.
      min_difference = std::numeric_limits<DataType>::infinity();
      return FindNearestFeature(target, min_difference, root_);
    }

  template <UInt K>
    void KDTree<K>::
    Show() const {
      Show(root_);
    }

  // Deep-copies the subtree rooted at t. The old version copied only the top
  // node and shared its children with the source, causing double frees.
  template <UInt K>
    typename KDTree<K>::KDNode * KDTree<K>::
    Clone(KDNode * t) const {
      if (NULL == t) {
        return NULL;
      }
      return new KDNode(t->feature, Clone(t->left), Clone(t->right), t->depth);
    }

  // Frees the subtree rooted at t (post-order) and resets t to NULL.
  template <UInt K>
    void KDTree<K>::
    Clean(KDNode * & t) {
      if (t != NULL) {
        Clean(t->left);
        Clean(t->right);
        delete t;
        t = NULL;
      }
    }

  template <UInt K>
    void KDTree<K>::
    SortFeature(std::vector<Feature> & features, int index) {
      std::sort(features.begin(), features.end(), Comparator(index));
    }

  // Builds the subtree for matrix_feature at the given depth:
  //  - sort on axis depth % K and store the median element at t;
  //  - recurse on the elements before the median (left) and after it (right).
  template <UInt K>
    void KDTree<K>::
    Build(const std::vector<Feature> & matrix_feature, KDNode * & t, int depth) {
      if (matrix_feature.empty()) {
        t = NULL;
        return;
      }

      std::vector<Feature> sorted(matrix_feature);
      SortFeature(sorted, depth % K);

      const std::size_t middle = sorted.size() / 2;
      t = new KDNode(sorted[middle], NULL, NULL, depth);

      const std::vector<Feature> lower(sorted.begin(), sorted.begin() + middle);
      const std::vector<Feature> upper(sorted.begin() + middle + 1, sorted.end());
      Build(lower, t->left, depth + 1);
      Build(upper, t->right, depth + 1);
    }

  // Euclidean distance between two features of equal dimension.
  template <UInt K>
    DataType KDTree<K>::
    Feature2FeatureDifference(const Feature & f1, const Feature & f2) const {
      assert(f1.data.size() == f2.data.size());
      DataType sum = 0;
      for (std::size_t i = 0; i < f1.data.size(); ++i) {
        const DataType d = f1.data[i] - f2.data[i];
        sum += d * d;
      }
      return std::sqrt(sum);
    }

  // Searches the subtree rooted at t.
  //
  // min_difference is the best distance found so far; it is lowered whenever
  // a closer feature is seen. Returns that closer feature's id, or -1 if
  // nothing in this subtree beats the incoming min_difference.
  template <UInt K>
    int KDTree<K>::
    FindNearestFeature(const Feature & target, DataType & min_difference,
                       KDNode * t) const {
      if (NULL == t) {
        return -1;
      }

      int best_id = -1;
      const DataType diff_node = Feature2FeatureDifference(target, t->feature);
      if (diff_node < min_difference) {
        min_difference = diff_node;
        best_id = t->feature.id;
      }

      // Signed offset of the target from this node's splitting plane.
      // Equal coordinates descend right first, as in the original code.
      const int axis = t->depth % K;
      const DataType offset = target.data[axis] - t->feature.data[axis];
      KDNode * near_side = (offset < 0) ? t->left : t->right;
      KDNode * far_side = (offset < 0) ? t->right : t->left;

      int found = FindNearestFeature(target, min_difference, near_side);
      if (found != -1) {
        best_id = found;
      }

      // Every feature on the far side is at least |offset| away, so it can
      // only win if the plane is closer than the best distance so far. (The
      // old code compared against the child node's distance, which pruned far
      // less often.)
      if (std::fabs(offset) < min_difference) {
        found = FindNearestFeature(target, min_difference, far_side);
        if (found != -1) {
          best_id = found;
        }
      }
      return best_id;
    }

  // Pre-order dump of the subtree rooted at t; prints nothing for an empty
  // subtree (the old version dereferenced NULL when the tree was empty).
  template <UInt K>
    void KDTree<K>::Show(KDNode * t) const {
      if (NULL == t) {
        return;
      }
      std::cout << "ID: " << t->feature.id << std::endl;
      std::cout << "Data: ";
      for (std::size_t i = 0; i < t->feature.data.size(); ++i) {
        std::cout << t->feature.data[i] << " ";
      }
      std::cout << std::endl;
      if (t->left != NULL) {
        std::cout << "Left: " << t->feature.id << " -> " << t->left->feature.id << std::endl;
        Show(t->left);
      }
      if (t->right != NULL) {
        std::cout << "Right: " << t->feature.id << " -> " << t->right->feature.id << std::endl;
        Show(t->right);
      }
    }

} /* sx */



#endif /* end of include guard: KDTREE_H */

      下面是用于测试的 main.cc 文件：

// =============================================================================
//
//       Filename:  main.cc
//
//    Description:  K-D Tree
//
//        Version:  1.0
//        Created:  04/11/2013 04:28:02 PM
//       Revision:  none
//       Compiler:  g++
//
//         Author:  Geek SphinX(Perstare et Praestare), [email protected]
//   Organization:  Hefei University of Technology
//
// =============================================================================

#include "KDTree.h"
#include 
using namespace sx;
using namespace std;

int main(int argc, char *argv[]) {
  KDTree<2> kdtree;
  vector vf;
  int idx = 0;
  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 5; ++j) {
      vector vd;
      vd.push_back(i);
      vd.push_back(j);
      vf.push_back(Feature(vd, idx++));
    }
  }
  kdtree.Build(vf);
  kdtree.Show();
  Feature target;
  int n;
  DataType x, y;
  cin >> n;
  while (n--) {
    cin >> x >> y;
    vector vd;
    vd.push_back(x);
    vd.push_back(y);
    target = Feature(vd, 0);
    DataType md;
    int t = kdtree.FindNearestFeature(target, md);
    cout << "Result is " << t << endl;
    for (int i = 0; i < (int)vf[t].data.size(); ++i) {
      cout << vf[t].data[i] << " ";
    }
    cout << endl;
    cout << "Minimum Difference is " << md << endl;
  }
  return 0;
}


你可能感兴趣的:(Machine,Learning)