




这个类并不是我们通常所理解的二叉树结点类,而是二叉空间划分树(binary space-partitioning tree)的结点。每个父结点都会把当前空间划分为两个子空间;这种划分既可以是线性的,也可以基于核函数。该类定义在头文件 shark/Models/Trees/BinaryTree.h 中。

// NOTE(review): this listing appears to have been extracted from a web page and does not
// compile as shown. The class braces, the public/protected access specifiers and several
// template argument lists (boost::range_iterator<...>, std::pair<...> in splitList) are
// missing. Restore them from the upstream Shark header before building.
//
// Node of a binary space-partitioning tree built over a data set.
// Every inner node splits its region into two sub-regions. A point belongs to the
// left child when funct(point) < threshold() and to the right child otherwise.
// funct() is supplied by subclasses (KDTree returns one coordinate of the point).
// kernel() lets a subclass announce a kernel metric; the default returns NULL.
//
// Ownership: each node deletes its two children. The index list is allocated by the
// root and deleted by the root only. Child nodes receive pointers into it (see
// KDTree::buildTree).
// NOTE(review): the node owns raw pointers, but this listing declares no copy
// constructor or copy assignment. Copying a node would therefore double-delete;
// consider making the class non-copyable.
template <class InputT>
class BinaryTree
    typedef InputT value_type;

    // Root-node constructor for a data set of `size` points (asserts size > 0).
    // It allocates the shared index list but leaves it uninitialised. In this listing
    // the KDTree root constructor fills the list after the tree has been built.
    BinaryTree(std::size_t size)
    : mep_parent(NULL)
    , mp_left(NULL)
    , mp_right(NULL)
    , mp_indexList(NULL)
    , m_size(size)
    , m_nodes(0)
    , m_threshold(0.0)
        SHARK_ASSERT(m_size > 0);

        mp_indexList = new std::size_t[m_size];

    // Deletes both subtrees. Only the root (the node without a parent) frees the
    // index list, because child nodes merely point into it.
    virtual ~BinaryTree()
        if (mp_left != NULL) delete mp_left;
        if (mp_right != NULL) delete mp_right;
        if (mep_parent == NULL) delete [] mp_indexList;

    // Parent node, or NULL for the root.
    BinaryTree* parent()
    { return mep_parent; }

    const BinaryTree* parent() const
    { return mep_parent; }

    // Only the left child is checked. KDTree::buildTree creates both children together.
    bool hasChildren() const
    { return (mp_left != NULL); }

    bool isLeaf() const
    { return (mp_left == NULL); }

    // Left child node (the region where funct(point) < threshold()), or NULL for a leaf.
    BinaryTree* left()
    { return mp_left; }

    const BinaryTree* left() const
    { return mp_left; }

    // Right child node (the region where funct(point) >= threshold()), or NULL for a leaf.
    BinaryTree* right()
    { return mp_right; }

    const BinaryTree* right() const
    { return mp_right; }

    // Number of data points stored in this subtree.
    std::size_t size() const
    { return m_size; }

    // Number of nodes in this subtree; 0 until the subtree has been built.
    std::size_t nodes() const
    { return m_nodes; }

    // Data-set index of the `point`-th point stored in this subtree.
    std::size_t index(std::size_t point)const{
        return mp_indexList[point];

    // Signed offset funct(point) - threshold(): negative on the left side,
    // non-negative on the right side.
    double distanceFromPlane(value_type const& point) const{
        return funct(point) - m_threshold;

    // Split value used by this node.
    double threshold() const{
        return m_threshold;

    // Note: left() above returns the left child node. This function instead tests
    // whether a query point falls into the left sub-region.
    bool isLeft(value_type const& point) const
    { return (funct(point) < m_threshold); }

    // Complement of isLeft(): true when the point falls into the right sub-region.
    bool isRight(value_type const& point) const
    { return (funct(point) >= m_threshold); }

    // Kernel whose feature-space distance the tree uses. NULL (the default) makes
    // IterativeNNQuery fall back to plain distanceSqr().
    virtual AbstractKernelFunction const* kernel()const{
        //default is no kernel metric
        return NULL;

    // Lower bound on the squared distance between the query point and any point in
    // this node's region. A tighter bound (e.g. one that uses the triangle inequality)
    // lets the search prune more and therefore run faster.
    virtual double squaredDistanceLowerBound(value_type const& point) const = 0;

    // Child-node constructor. `list` points into the parent's index list and is not
    // owned by this node.
    // NOTE(review): m_threshold is not initialised here; it is only assigned by splitList().
    BinaryTree(BinaryTree* parent, std::size_t* list, std::size_t size)
    : mep_parent(parent)
    , mp_left(NULL)
    , mp_right(NULL)
    , mp_indexList(list)
    , m_size(size)
    , m_nodes(0)

    // Splitting function evaluated at `point`. isLeft()/isRight() compare its value
    // with the threshold; distanceFromPlane() subtracts the threshold from it.
    virtual double funct(value_type const& point) const = 0;

    // Partition this node's data, set m_threshold, and return the position in `points`
    // where the data was split.
    // `values` holds the split value of every point (KDTree passes each point's
    // coordinate on the cutting dimension). `points` holds the points themselves.
    // Both ranges are presumably reordered together via zipKeyValuePairs.
    template<class Range1, class Range2>
    typename boost::range_iterator::type splitList (Range1& values, Range2& points){
        typedef typename boost::range_iterator::type iterator1;
        typedef typename boost::range_iterator::type iterator2;

        iterator1 valuesBegin = boost::begin(values);
        iterator1 valuesEnd = boost::end(values);

        std::pair splitpoint = partitionEqually(zipKeyValuePairs(values,points)).iterators();
        iterator1 valuesSplitpoint = splitpoint.first;
        iterator2 pointsSplitpoint = splitpoint.second;
        if (valuesSplitpoint == valuesEnd) {
            // partitioning failed, all values are equal :(
            // The returned points iterator is then presumably the end of `points`.
            // KDTree::buildTree treats that as "cannot split".
            m_threshold = *valuesBegin;
            return splitpoint.second;

        // We don't want the threshold to be the value of an element but always in between two of them.
        // This ensures that no point of the training set lies on the boundary. This leads to more stable
        // results. So we use the mean of the found splitpoint and the nearest point on the other side
        // of the boundary.
        double maximum = *std::max_element(valuesBegin, valuesSplitpoint);
        m_threshold = 0.5*(maximum + *valuesSplitpoint);

        return pointsSplitpoint;

    // Parent node; NULL for the root. Not owned.
    BinaryTree* mep_parent;

    // Owned left child; NULL for a leaf.
    BinaryTree* mp_left;

    // Owned right child; NULL for a leaf.
    BinaryTree* mp_right;

    // This subtree's slice of the index list. Owned (and deleted) by the root only.
    std::size_t* mp_indexList;

    // Number of points in this subtree.
    std::size_t m_size;

    // Number of nodes in this subtree.
    std::size_t m_nodes;

    // Split threshold compared against funct(); see isLeft()/isRight().
    double m_threshold;




/// Parameters that control the recursive construction of a binary space-partitioning tree.
///
/// - maxDepth: how many further levels of splits are allowed below the current node.
///   When it reaches 0, KDTree::buildTree turns the node into a leaf.
/// - maxBucketSize: a node holding at most this many points becomes a leaf.
///
/// Passing 0 for either value to the two-argument constructor selects its default:
/// (practically) unlimited depth (0xffffffff) and one point per leaf.
class TreeConstruction
{
public:
    /// Default parameters: unlimited depth, one point per leaf.
    TreeConstruction()
    : m_maxDepth(0xffffffff)
    , m_maxBucketSize(1)
    { }

    TreeConstruction(TreeConstruction const& other)
    : m_maxDepth(other.m_maxDepth)
    , m_maxBucketSize(other.m_maxBucketSize)
    { }

    /// @param maxDepth      maximum remaining tree depth; 0 selects the default 0xffffffff
    /// @param maxBucketSize maximum number of points in a leaf; 0 selects the default 1
    TreeConstruction(unsigned int maxDepth, unsigned int maxBucketSize)
    : m_maxDepth(maxDepth ? maxDepth : 0xffffffff)
    , m_maxBucketSize(maxBucketSize ? maxBucketSize : 1)
    { }

    /// Parameters for the children of the current node: maxDepth reduced by one.
    ///
    /// The decrement is applied to the stored value directly instead of going through
    /// the two-argument constructor. That constructor maps 0 to "unlimited", so a node
    /// with maxDepth 1 would hand its children an unlimited depth, and the depth limit
    /// could never be reached. A depth of 0 stays 0 (no unsigned wrap-around).
    TreeConstruction nextDepthLevel() const
    {
        TreeConstruction next(*this);
        if (next.m_maxDepth > 0) --next.m_maxDepth;
        return next;
    }

    /// Remaining number of split levels allowed below the current node.
    unsigned int maxDepth() const
    { return m_maxDepth; }

    /// Maximum number of points a leaf may hold.
    unsigned int maxBucketSize() const
    { return m_maxBucketSize; }

protected:
    /// Remaining depth; 0 means "make a leaf now".
    unsigned int m_maxDepth;

    /// Maximum number of points per leaf (always >= 1).
    unsigned int m_maxBucketSize;
};



// NOTE(review): as with BinaryTree, this listing lost braces, access specifiers and
// template argument lists, e.g. BinaryTree<InputT>, Data<InputT>, DataView<...>,
// static_cast<self_type*> and boost::iterator_range<iterator>. Restore them before building.
//
// KD-tree: a BinaryTree whose nodes each split on a single coordinate.
// Each inner node picks the dimension in which its points have the largest extent,
// and splits at a threshold computed by BinaryTree::splitList(). Every node's region
// is therefore an axis-aligned box, which squaredDistanceLowerBound() exploits.
template <class InputT>
class KDTree : public BinaryTree
    typedef KDTree self_type;
    typedef BinaryTree base_type;

    // Build a KD-tree over all elements of `dataset`. `tc` limits depth and bucket size.
    KDTree(Data const& dataset, TreeConstruction tc = TreeConstruction())
    : base_type(dataset.numberOfElements())
    , m_cutDim(0xffffffff){
        typedef DataView const> PointSet;
        PointSet points(dataset);

        // NOTE(review): as listed, `elements` is never filled and buildTree() is never
        // called (`tc` is unused). The two blank lines below presumably held that code:
        // fill `elements` with iterators over `points`, then call buildTree(tc, elements).
        std::vector<typename boost::range_iterator::type> elements(m_size);


        // Once buildTree() has reordered `elements` so that each node's points are
        // contiguous, record that order as data-set indices in the shared index list.
        for(std::size_t i = 0; i != m_size; ++i){
            mp_indexList[i] = elements[i].index();

    // Lower edge of this node's cell along dimension `dim`. This is the threshold of
    // the closest ancestor that split on `dim` with this subtree on its right side,
    // or -1e100 (effectively unbounded) if there is no such ancestor.
    double lower(std::size_t dim) const{
        self_type* parent = static_cast(mep_parent);
        if (parent == NULL) return -1e100; // root node: unbounded below

        if (parent->m_cutDim == dim && parent->mp_right == this)
            return parent->threshold();
            return parent->lower(dim);

    // Upper edge of this node's cell along dimension `dim`. This is the threshold of
    // the closest ancestor that split on `dim` with this subtree on its left side,
    // or +1e100 (effectively unbounded) if there is no such ancestor.
    double upper(std::size_t dim) const{
        self_type* parent = static_cast(mep_parent);
        if (parent == NULL) return +1e100;

        if (parent->m_cutDim == dim && parent->mp_left == this) 
            return parent->threshold();
            return parent->upper(dim);

    // Squared Euclidean distance from `reference` to this node's axis-aligned cell,
    // i.e. [lower(d), upper(d)] in every dimension d. It is 0 if the point lies inside
    // the cell, and it lower-bounds the distance to any point stored below this node.
    double squaredDistanceLowerBound(InputT const& reference) const
        double ret = 0.0;
        for (std::size_t d = 0; d != reference.size(); d++)
            double v = reference(d);
            double l = lower(d);
            double u = upper(d);
            if (v < l){
                ret += sqr(l-v);
            else if (v > u){
                ret += sqr(v-u);
        return ret;

    // Base-class members used unqualified below (the base class is a dependent type).
    using base_type::mep_parent;
    using base_type::mp_left;
    using base_type::mp_right;
    using base_type::mp_indexList;
    using base_type::m_size;
    using base_type::m_nodes;

    // Child-node constructor used by buildTree(). The cutting dimension stays unset
    // (0xffffffff) until buildTree() runs on this node.
    KDTree(KDTree* parent, std::size_t* list, std::size_t size)
    : base_type(parent, list, size)
    , m_cutDim(0xffffffff)
    { }

    // Recursively build the subtree for `points`, a range of iterators into the data
    // that splitList() presumably reorders in place.
    // NOTE(review): as listed, neither leaf case below is followed by `return;`, so
    // execution would fall through into the split. The early returns appear to have
    // been lost.
    template<class Range>
    void buildTree(TreeConstruction tc, Range& points){
        typedef typename boost::range_iterator::type iterator;

        iterator begin = boost::begin(points);
        iterator end = boost::end(points);

        // Depth limit reached or few enough points left: make this node a leaf.
        if (tc.maxDepth() == 0 || m_size <= tc.maxBucketSize()){
            m_nodes = 1; 

        m_cutDim = calculateCuttingDimension(points);

        // Collect every point's coordinate on the cutting dimension.
        std::vector<double> distance(m_size);
        iterator point = begin;
        for(std::size_t i = 0; i != m_size; ++i,++point){
            distance[i] = get(**point,m_cutDim);

        // Partition this node's points; splitList() also sets the threshold.
        iterator split = this->splitList(distance,points);
        if (split == end){
            // Partitioning failed (all values equal), so make this node a leaf.
            m_nodes = 1;
        std::size_t leftSize = split-begin;

        // Create both children over consecutive slices of the shared index list,
        // then build them recursively with one level less of allowed depth.
        mp_left = new KDTree(this, mp_indexList, leftSize);
        mp_right = new KDTree(this, mp_indexList + leftSize, m_size - leftSize);

        boost::iterator_range left(begin,split);
        boost::iterator_range right(split,end);
        ((KDTree*)mp_left)->buildTree(tc.nextDepthLevel(), left);
        ((KDTree*)mp_right)->buildTree(tc.nextDepthLevel(), right);
        m_nodes = 1 + mp_left->nodes() + mp_right->nodes();

    // Choose the cutting dimension for this node: the dimension in which the bounding
    // box of the node's points has the largest extent. Ties keep the lowest dimension.
    // NOTE(review): `point` starts at `begin`, whose values already seeded L and U, and
    // the loop runs m_size - 1 times. The first point is therefore visited twice and the
    // last point is never visited. `point` should presumably be advanced once before
    // the loop.
    template<class Range>
    std::size_t calculateCuttingDimension(Range const& points)const{
        typedef typename boost::range_iteratorconst>::type iterator;

        iterator begin = boost::begin(points);

        // Bounding box of the node's points: per-dimension minimum L and maximum U.
        InputT L = **begin;
        InputT U = **begin;
        std::size_t dim = L.size();
        iterator point = begin;
        for (std::size_t i=1; i != m_size; ++i,++point){
            for (std::size_t d = 0; d != dim; d++){
                double v = (**point)[d];
                if (v < L[d]) L[d] = v;
                if (v > U[d]) U[d] = v;

        // Pick the dimension with the largest extent U - L as the cutting dimension.
        std::size_t cutDim = 0;
        double extent = U[0] - L[0];
        for (std::size_t d = 1; d != dim; d++)
            double e = U[d] - L[d];
            if (e > extent)
                extent = e;
                cutDim = d;
        return cutDim;

    // Splitting function: the point's coordinate on the cutting dimension.
    // isLeft()/isRight() compare it with the threshold to choose the sub-region.
    double funct(InputT const& reference) const{
        return reference[m_cutDim];

    // Dimension this node splits on; 0xffffffff until buildTree() sets it.
    std::size_t m_cutDim;



// NOTE(review): template argument lists and the `public:` specifier were lost in this
// listing (e.g. Batch<InputType>, std::vector<DistancePair>,
// LabeledData<InputType, LabelType>).
//
// Interface of a k-nearest-neighbour search over a labeled training set.
template<class InputType,class LabelType>
class AbstractNearestNeighbors{
    // key: distance between the neighbour and the query point; value: the neighbour's label.
    typedef KeyValuePair<double,LabelType> DistancePair;
    typedef typename Batch::type BatchInputType;

    // Return the k nearest neighbours of every pattern in `batch` as one flat vector.
    // TreeNearestNeighbors stores neighbour i of pattern p at index i + p*k.
    virtual std::vector getNeighbors(BatchInputType const& batch, std::size_t k) const = 0;

    // The labeled training set searched by this object.
    virtual LabeledDataconst& dataset()const = 0;

    virtual ~AbstractNearestNeighbors() {}



// NOTE(review): template argument lists were lost in this listing (on the base class,
// LabeledData, BinaryTree, Batch, std::vector, IterativeNNQuery and DataView).
//
// Nearest-neighbour search backed by a pre-built BinaryTree (e.g. a KDTree) over the
// inputs of `dataset`. The data set is held by reference and the tree by raw pointer;
// this class deletes neither, so both must outlive it.
template<class InputType, class LabelType>
class TreeNearestNeighbors:public AbstractNearestNeighbors
    typedef AbstractNearestNeighbors base_type;

    typedef LabeledData Dataset;
    typedef BinaryTree Tree;
    typedef typename base_type::DistancePair DistancePair;
    typedef typename Batch::type BatchInputType;

    TreeNearestNeighbors(Dataset const& dataset, Tree const* tree)
    : m_dataset(dataset), m_inputs(dataset.inputs()), m_labels(dataset.labels()),mep_tree(tree)
    { }

    // For every pattern p in `patterns`, return its k nearest neighbours.
    // Neighbour i of pattern p is stored at index i + p*k.
    std::vector getNeighbors(BatchInputType const& patterns, std::size_t k)const{
        std::size_t numPoints = shark::size(patterns);
        std::vector results(k*numPoints);
        for(std::size_t p = 0; p != numPoints; ++p){
            IterativeNNQuery const> > query(mep_tree, m_inputs, get(patterns, p));
            // Pull the k nearest neighbours of pattern p from the incremental query,
            // in order, and store their labels.
            // NOTE(review): the right-hand side of `result =` is missing in this listing
            // (presumably query.next()). Also, only `.value` is assigned. `.key` (the
            // distance read by ONE_OVER_DISTANCE weighting) is never set here.
            for(std::size_t i = 0; i != k; ++i){
                typename IterativeNNQuery const> >::result_type result =;
                results[i+p*k].value= m_labels[result.second]; 
        return results;

    // The training set this object searches.
    LabeledDataconst& dataset()const {
        return m_dataset;

    Dataset const& m_dataset; // training set (by reference); only returned by dataset(), which NearestNeighborClassifier's constructor reads to count classes
    DataView const> m_inputs; // training inputs, searched by IterativeNNQuery
    DataView const> m_labels; // training labels, looked up by neighbour index
    Tree const* mep_tree; // the pre-built search tree (e.g. a KDTree); not owned




// NOTE(review): several parts of this class appear to have been lost from this listing:
// - the `template <class DataContainer>` header (DataContainer is used below but never declared);
// - braces and several `else` keywords;
// - template argument lists (e.g. BinaryTree<value_type>, AbstractKernelFunction<value_type>,
//   boost::intrusive::rbtree<TraceLeaf>);
// - a number of statements.
// See the NOTE(review) comments below, and restore from the upstream Shark header before building.
//
// Incremental nearest-neighbour query on a BinaryTree. Each call to next() returns
// the next-closest data point to the reference point as (distance, data-set index).
// The search mirrors the visited part of the tree in a "trace" tree of TraceNodes.
// Leaves whose points have become candidates are kept in m_queue, ordered by their
// squared distance to the reference point.
class IterativeNNQuery
    typedef typename DataContainer::value_type value_type;
    typedef BinaryTree tree_type;
    typedef AbstractKernelFunction kernel_type;
    // (distance to the reference point, data-set index of the neighbour)
    typedef std::pair<double, std::size_t> result_type;

    // @param tree   root of the search tree (not owned)
    // @param data   the points the tree indexes (held by reference)
    // @param point  the reference point (copied)
    IterativeNNQuery(tree_type const* tree, DataContainer const& data, value_type const& point)
    : m_data(data)
    , m_reference(point)
    , m_nextIndex(0)
    , mp_trace(NULL)
    , mep_head(NULL)
    , m_squaredRadius(0.0)
    , m_neighbors(0)
        mp_trace = new TraceNode(tree, NULL, m_reference);
        TraceNode* tn = mp_trace;
        // Descend from the root to the leaf whose cell contains the reference point.
        // Both trace children are created at every level on the way down.
        while (tree->hasChildren())
            tn->createLeftNode(tree, m_data, m_reference);
            tn->createRightNode(tree, m_data, m_reference);
            bool left = tree->isLeft(m_reference);
            tn = (left ? tn->mep_left : tn->mep_right);
            tree = (left ? tree->left() : tree->right());
        mep_head = tn->mep_parent;
        insertIntoQueue((TraceLeaf*)tn); // enqueue the leaf containing the reference point first
        m_squaredRadius = mp_trace->squaredRadius(m_reference);

    // Deletes the whole trace tree; each TraceNode deletes its children.
    ~IterativeNNQuery() {
        delete mp_trace;

    // Number of neighbours reported so far (see the NOTE on getNextPoint).
    std::size_t neighbors() const {
        return m_neighbors;

    /// find and return the next nearest neighbor
    /// Throws once as many neighbours as the tree holds points have been reported.
    result_type next() {
        if (m_neighbors >= mp_trace->m_tree->size()) 
            throw SHARKEXCEPTION("[IterativeNNQuery::next] no more neighbors available");

        assert(! m_queue.empty());

        if (m_neighbors > 0){
            TraceLeaf& q = *m_queue.begin();
            // The closest queued leaf still has points that have not been returned.
            // NOTE(review): as listed there is no `else` branch removing the exhausted
            // leaf `q` from m_queue, so the same leaf would stay at the front.
            if (m_nextIndex < q.m_tree->size()){
                return getNextPoint(q);
        // Either the queue is empty, or its closest candidate is farther away than the
        // radius of the region already covered. An uncovered point might then be
        // closer, so the covered region must be widened.
        if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
            // Walk from the head towards the root, moving the head past COMPLETE nodes.
            // NOTE(review): as listed the loop never calls enqueue(tn), so nothing new is
            // added to the queue here; that call appears to have been lost.
            TraceNode* tn = mep_head;
            while (tn != NULL){
                if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
                tn = tn->mep_parent;

            // Recompute the radius of the covered region.
            m_squaredRadius = mp_trace->squaredRadius(m_reference);
        m_nextIndex = 0;
        return getNextPoint(*m_queue.begin());

    // Number of leaves currently in the candidate queue.
    std::size_t queuesize() const{ 
        return m_queue.size();

    // State of a trace node during the search.
    enum Status
        NONE,            // no leaf below this node has been enqueued yet
        PARTIAL,         // some, but not all, leaves below this node have been enqueued
        COMPLETE,        // every leaf below this node has been enqueued

    // Node of the trace tree. It mirrors one BinaryTree node visited by the search and
    // caches that node's squared-distance lower bound to the reference point.
    class TraceNode
        TraceNode(tree_type const* tree, TraceNode* parent, value_type const& reference)
        : m_tree(tree)
        , m_status(NONE)
        , mep_parent(parent)
        , mep_left(NULL)
        , mep_right(NULL)
        , m_squaredDistance(tree->squaredDistanceLowerBound(reference))
        { }

        // Deletes both trace children, recursively.
        virtual ~TraceNode()
            if (mep_left != NULL) delete mep_left;
            if (mep_right != NULL) delete mep_right;

        // Create the trace node for the left/right child of `tree`: a TraceNode for an
        // inner node, or a TraceLeaf (which also computes the point distance) for a leaf.
        // NOTE(review): the `else` before each TraceLeaf assignment is missing in this listing.
        void createLeftNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
            if (tree->left()->hasChildren())
                mep_left = new TraceNode(tree->left(), this, reference);
                // otherwise the left child is a leaf: create a TraceLeaf for it
                mep_left = new TraceLeaf(tree->left(), this, data, reference);
        void createRightNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
            if (tree->right()->hasChildren())
                mep_right = new TraceNode(tree->right(), this, reference);
                mep_right = new TraceLeaf(tree->right(), this, data, reference);

        /// Compute the squared distance of the area not
        /// yet covered by the queue to the reference point.
        /// This is also referred to as the squared "radius"
        /// of the area covered by the queue (in fact, it is
        /// the radius of the largest sphere around the
        /// reference point that fits into the covered area).
        /// NONE: this node's own lower bound. PARTIAL: the minimum over both children.
        /// COMPLETE: 1e100, i.e. the subtree is fully covered.
        double squaredRadius(value_type const& ref) const{
            if (m_status == NONE) return m_squaredDistance;
            else if (m_status == PARTIAL)
                double l = mep_left->squaredRadius(ref);
                double r = mep_right->squaredRadius(ref);
                return std::min(l, r);
            else return 1e100;

        // The BinaryTree node this trace node mirrors.
        tree_type const* m_tree;

        // Search state of this subtree.
        Status m_status;

        // Parent trace node; NULL for the trace root.
        TraceNode* mep_parent;

        // Owned trace children; NULL until created.
        TraceNode* mep_left;

        TraceNode* mep_right;

        // Lower bound on the squared distance from the reference point to this node's cell.
        double m_squaredDistance;

    /// hook type for intrusive container
    typedef boost::intrusive::set_base_hook<> HookType;

    // Trace node for a tree leaf; it can be linked into the intrusive queue m_queue.
    class TraceLeaf : public TraceNode, public HookType

        // Compute the squared distance from the leaf's first point (index(0)) to the
        // reference point. The kernel's feature-space distance is used if the tree
        // reports a kernel.
        // NOTE(review): the `else` before the distanceSqr() line is missing in this listing.
        // Only the first point is measured, and getNextPoint() reports that same distance
        // for every point of the leaf. Distances are therefore exact only for leaves
        // holding a single point (bucket size 1, the TreeConstruction default).
        TraceLeaf(tree_type const* tree, TraceNode* parent, DataContainer const& data, value_type const& ref)
        : TraceNode(tree, parent, ref){
            if(tree->kernel() != NULL)
                m_squaredPtDistance = tree->kernel()->featureDistanceSqr(data[tree->index(0)], ref);
                m_squaredPtDistance = distanceSqr(data[tree->index(0)], ref);

        ~TraceLeaf() { }

        // Queue order: by squared point distance. Ties are broken by the tree-node
        // address, so that two different leaves never compare equal.
        inline bool operator < (TraceLeaf const& rhs) const{
            if (m_squaredPtDistance == rhs.m_squaredPtDistance) 
                return (this->m_tree < rhs.m_tree);
                return (m_squaredPtDistance < rhs.m_squaredPtDistance);

        // Squared distance from the leaf's first point to the reference point.
        double m_squaredPtDistance;

    // Mark `leaf` as enqueued (COMPLETE) and propagate the status up the trace tree:
    // - a NONE parent becomes PARTIAL and, as listed, the walk continues upwards;
    // - a PARTIAL parent becomes COMPLETE once its other child is COMPLETE; otherwise
    //   the walk stops;
    // - the walk also stops at the root.
    // NOTE(review): as listed, the leaf is never inserted into m_queue (the blank first
    // line of the body presumably held that insertion). The `else` for the case where
    // `tn` is the right child is also missing. Check against upstream whether the
    // NONE -> PARTIAL case should stop the walk.
    void insertIntoQueue(TraceLeaf* leaf){

        TraceNode* tn = leaf;
        tn->m_status = COMPLETE;
        while (true){
            TraceNode* par = tn->mep_parent;
            if (par == NULL) break;
            if (par->m_status == NONE){
                par->m_status = PARTIAL;
            else if (par->m_status == PARTIAL){
                // If both children are COMPLETE, the parent becomes COMPLETE as well.
                if (par->mep_left == tn){
                    if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
                    else break;
                    if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
                    else break;
            tn = par;

    // Build the (distance, data-set index) pair for point m_nextIndex of `leaf`.
    // NOTE(review): as listed, neither m_nextIndex nor m_neighbors is ever incremented,
    // so next() could not advance past the first point. The increments (presumably
    // here) appear to have been lost.
    result_type getNextPoint(TraceLeaf const& leaf){
        double dist = std::sqrt(leaf.m_squaredPtDistance);
        std::size_t index = leaf.m_tree->index(m_nextIndex);
        return std::make_pair(dist,index);

    /// Recursively descend the node and enqueue
    /// all points in cells intersecting the
    /// current bounding sphere.
    /// Subtrees that are already COMPLETE are skipped, as are subtrees whose
    /// lower-bound distance is not smaller than the closest queued candidate.
    // NOTE(review): the branch bodies below are empty in this listing. They presumably:
    // - create the missing trace children (createLeftNode/createRightNode);
    // - recurse into the child on the reference point's side first;
    // - for a leaf, pass `leaf` to insertIntoQueue().
    void enqueue(TraceNode* tn){
        // Nothing to do if every leaf of this subtree has already been enqueued.
        if (tn->m_status == COMPLETE) return;
        if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;

        const tree_type* tree = tn->m_tree;
        // Inner node: extend the trace tree where children are still missing.
        if (tree->hasChildren()){
            if (tn->mep_left == NULL){
            if (tn->mep_right == NULL){

            if (tree->isLeft(m_reference))
            TraceLeaf* leaf = (TraceLeaf*)tn;

    // Candidate leaves, ordered by TraceLeaf::operator< (closest first).
    typedef boost::intrusive::rbtree QueueType;

    // The indexed data (held by reference).
    DataContainer const& m_data;

    // The query point.
    value_type m_reference;

    // Candidate queue; next() reports points from its front leaf.
    QueueType m_queue;

    // Position of the next point to report inside the front leaf.
    std::size_t m_nextIndex;

    // The search builds a trace tree mirroring the visited part of the BinaryTree.
    // This is its root (owned, deleted in the destructor).
    TraceNode* mp_trace;

    // Starting point of next()'s upward walk. Initially the parent of the first leaf;
    // moved towards the root as subtrees become COMPLETE.
    TraceNode* mep_head;

    // Squared radius of the region around the reference point already covered by
    // the queue (see TraceNode::squaredRadius).
    double m_squaredRadius;

    // Number of neighbours reported so far.
    std::size_t m_neighbors;



// NOTE(review): template argument lists were lost in this listing, e.g.
// AbstractModel<InputType, unsigned int>, AbstractNearestNeighbors<InputType, unsigned int>
// and boost::shared_ptr<State>. Several braces and the access specifiers were lost
// too; restore them before building.
//
// k-nearest-neighbour classifier. For each pattern, eval() asks the nearest-neighbour
// algorithm for the m_neighbors closest training points. It then predicts the class
// with the largest (optionally distance-weighted) vote. The algorithm object is not
// owned and must outlive the classifier. The only model parameter is the number of
// neighbours k.
template <class InputType>
class NearestNeighborClassifier : public AbstractModelunsigned int>
    typedef AbstractNearestNeighborsunsigned int> NearestNeighbors;
    typedef AbstractModelunsigned int> base_type;
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::BatchOutputType BatchOutputType;

    // How the neighbours' votes are weighted.
    enum DistanceWeights
        UNIFORM,                // every neighbour counts 1 (plain majority vote)
        ONE_OVER_DISTANCE,      // every neighbour counts 1/distance

    // @param algorithm  nearest-neighbour search (not owned). Its dataset() is read once
    //                   to determine the number of classes.
    // @param neighbors  number of neighbours k (default 3)
    // Votes are UNIFORM until setDistanceWeightType() is called.
    NearestNeighborClassifier(NearestNeighbors const* algorithm, std::size_t neighbors = 3)
    : m_algorithm(algorithm)
    , m_classes(numberOfClasses(algorithm->dataset()))
    , m_neighbors(neighbors)
    , m_distanceWeights(UNIFORM)
    { }

    std::string name() const
    { return "NearestNeighborClassifier"; }

    // Number of neighbours k used per prediction.
    std::size_t neighbors() const{
        return m_neighbors;

    // NOTE(review): the body of setNeighbors is missing in this listing (presumably
    // `m_neighbors = neighbors;`).
    void setNeighbors(std::size_t neighbors){

    DistanceWeights getDistanceWeightType() const
    { return m_distanceWeights; }

    void setDistanceWeightType(DistanceWeights dw)
    { m_distanceWeights = dw; }

    // The parameter vector has a single entry: k as a double.
    virtual RealVector parameterVector() const{
        RealVector parameters(1);
        parameters(0) = (double)m_neighbors;
        return parameters;

    // Set k from a one-element vector; the value is truncated to an integer.
    // NOTE(review): the check that the value is a positive integer is commented out.
    // For example, 2.7 becomes 2, a value below 1 is not rejected, and a value <= -1
    // makes the size_t conversion undefined behaviour. The messages also name
    // "SoftNearestNeighborClassifier" rather than this class.
    virtual void setParameterVector(RealVector const& newParameters){
        SHARK_CHECK(newParameters.size() == 1,
            "[SoftNearestNeighborClassifier::setParameterVector] invalid number of parameters");
        //~ SHARK_CHECK((std::size_t)newParameters(0) == newParameters(0) && newParameters(0) >= 1.0,
            //~ "[SoftNearestNeighborClassifier::setParameterVector] invalid number of neighbors");
        m_neighbors = (std::size_t)newParameters(0);

    // Always 1 (the number of neighbours).
    virtual std::size_t numberOfParameters() const{
        return 1;

    // This model keeps no per-evaluation state.
    boost::shared_ptr createState()const{
        return boost::shared_ptr(new EmptyState());

    using base_type::eval;

    // Predict a class label for every pattern in the batch. Ties in the vote go to the
    // smallest class index, because std::max_element returns the first maximum.
    // `state` is unused.
    void eval(BatchInputType const& patterns, BatchOutputType& output, State& state)const{
        std::size_t numPatterns = shark::size(patterns);
        // Fetch the m_neighbors nearest neighbours of every pattern in one call.
        // Neighbour k of pattern p is at index p*m_neighbors + k.
        std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns,m_neighbors);

        // NOTE(review): `output` is written below without being sized in this listing.
        // These blank lines presumably held the resize of `output` to numPatterns.

        for(std::size_t p = 0; p != numPatterns;++p){
            std::vector<double> histogram(m_classes, 0.0);
            for ( std::size_t k = 0; k != m_neighbors; ++k){
                // Add this neighbour's vote to its label: 1 for UNIFORM, otherwise
                // 1/distance. An exact match (distance < 1e-100) gets 1e100 so that it
                // dominates the vote.
                // NOTE(review): the `else {` before the distance-weighted lines appears
                // to be missing in this listing.
                if (m_distanceWeights == UNIFORM) histogram[neighbors[p*m_neighbors+k].value]++;
                    double d = neighbors[p*m_neighbors+k].key;
                    if (d < 1e-100) histogram[neighbors[p*m_neighbors+k].value] += 1e100;
                    else histogram[neighbors[p*m_neighbors+k].value] += 1.0 / d;
            output(p) = static_cast<unsigned int>(std::max_element(histogram.begin(),histogram.end()) - histogram.begin());

    // Serialization covers only k and the number of classes. The distance-weight
    // setting and the search algorithm are not saved.
    void read(InArchive& archive){
        archive & m_neighbors;
        archive & m_classes;

    void write(OutArchive& archive) const{
        archive & m_neighbors;
        archive & m_classes;

    // Nearest-neighbour search; not owned.
    NearestNeighbors const* m_algorithm;

    // Number of classes in the training set (size of the vote histogram).
    std::size_t m_classes;

    // Number of neighbours k.
    std::size_t m_neighbors;

    // Vote weighting scheme.
    DistanceWeights m_distanceWeights;




using namespace shark;
using namespace std;

int main(int argc, char **argv) {
    if(argc < 2) {
        cerr << "usage: " << argv[0] << " (filename)" << endl;
    // read data
    ClassificationDataset data;
    try {
        importCSV(data, argv[1], LAST_COLUMN, ' ');
    catch (...) {
        cerr << "unable to read data from file " <<  argv[1] << endl;

    cout << "number of data points: " << data.numberOfElements()
         << " number of classes: " << numberOfClasses(data)
         << " input dimension: " << inputDimension(data) << endl;

    // split data into training and test set
    ClassificationDataset dataTest = splitAtElement(data, static_cast<std::size_t>(.5 * data.numberOfElements()));
    cout << "training data points: " << data.numberOfElements() << endl;
    cout << "test data points: " << dataTest.numberOfElements() << endl;

    //create a binary search tree and initialize the search algorithm - a fast tree search
    KDTree tree(data.inputs());
    TreeNearestNeighborsunsigned int> algorithm(data,&tree);
    //instantiate the classifier
    const unsigned int K = 1; // number of neighbors for kNN
    NearestNeighborClassifier KNN(&algorithm,K);

    // evaluate classifier
    ZeroOneLoss<unsigned int> loss;
    Data<unsigned int> prediction = KNN(data.inputs());
    cout << K << "-KNN on training set accuracy: " << 1. - loss.eval(data.labels(), prediction) << endl;
    prediction = KNN(dataTest.inputs());
    cout << K << "-KNN on test set accuracy:     " << 1. - loss.eval(dataTest.labels(), prediction) << endl;



// Alternative set-up, not a standalone program. These lines are meant to replace the
// KDTree / TreeNearestNeighbors / NearestNeighborClassifier lines in main(); `data` and
// `K` refer to main's variables. The search tree here is a KHCTree built over the
// training inputs with a linear kernel, instead of a KDTree.
LinearKernel<RealVector> kernel;
KHCTree<RealVector> tree(data.inputs(), &kernel);
TreeNearestNeighbors<RealVector, unsigned int> algorithm(data, &tree);
NearestNeighborClassifier<RealVector> KNN(&algorithm, K);



// Alternative set-up, not a standalone program: a fragment to substitute into main(),
// where `data` and `K` are defined. No search tree is built here; SimpleNearestNeighbors
// is given the data set and a linear kernel directly.
// NOTE(review): presumably a brute-force search over all training points; confirm in
// the Shark documentation.
LinearKernel<> kernel;
SimpleNearestNeighbors<RealVector, unsigned int> algorithm(data, &kernel);
NearestNeighborClassifier<RealVector> KNN(&algorithm, K);

