Shark Source Code Analysis (10): The KNN Algorithm

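I have already covered this algorithm in detail in an earlier blog post. Although it looks very simple, searching for the k nearest neighbors takes some skill, which is worth stressing again here. If the input dimension is low, a tree structure (a kd-tree) can speed up the search. If the input dimension is high, a tree-based search is not much faster than computing the distances between all pairs of points. The code presented below is also organized around a kd-tree.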
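For comparison with the tree-based search discussed below, here is a minimal brute-force sketch that computes the distance from the query to every point and keeps the k smallest. It is not Shark code; `Point`, `squaredDistance` and `bruteForceNeighbors` are names made up for this illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

typedef std::vector<double> Point;  // hypothetical point type for this sketch

// Squared Euclidean distance between two points of equal dimension.
double squaredDistance(Point const& a, Point const& b) {
    double sum = 0.0;
    for (std::size_t d = 0; d != a.size(); ++d) {
        double diff = a[d] - b[d];
        sum += diff * diff;
    }
    return sum;
}

// Returns the indices of the k points in `data` closest to `query`,
// ordered from nearest to farthest. k is clamped to data.size().
std::vector<std::size_t> bruteForceNeighbors(
    std::vector<Point> const& data, Point const& query, std::size_t k)
{
    std::vector<std::pair<double, std::size_t> > dist(data.size());
    for (std::size_t i = 0; i != data.size(); ++i)
        dist[i] = std::make_pair(squaredDistance(data[i], query), i);

    k = std::min(k, dist.size());
    // Only the first k entries need to be in sorted order.
    std::partial_sort(dist.begin(), dist.begin() + k, dist.end());

    std::vector<std::size_t> result(k);
    for (std::size_t i = 0; i != k; ++i) result[i] = dist[i].second;
    return result;
}
```

Every query touches all points, so the cost grows linearly with the size of the data set; the kd-tree tries to avoid exactly this when the dimension is low.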

The distance does not have to be Euclidean; a kernel-based distance can be used as well. Likewise, there are trees built on kernel-based distances.
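A kernel-based distance is the distance in the feature space induced by the kernel, which can be written with kernel evaluations only: ||phi(x) - phi(y)||^2 = k(x,x) - 2k(x,y) + k(y,y). The sketch below illustrates this identity with a linear kernel; `Point`, `linearKernel` and `kernelDistanceSqr` are illustrative names, not Shark's API.

```cpp
#include <cstddef>
#include <vector>

typedef std::vector<double> Point;  // hypothetical point type for this sketch

// Linear kernel k(x, y) = <x, y>, used here only as a simple example.
double linearKernel(Point const& x, Point const& y) {
    double sum = 0.0;
    for (std::size_t d = 0; d != x.size(); ++d) sum += x[d] * y[d];
    return sum;
}

// Squared distance in the feature space induced by a kernel:
// ||phi(x) - phi(y)||^2 = k(x,x) - 2 k(x,y) + k(y,y).
template <class Kernel>
double kernelDistanceSqr(Kernel kernel, Point const& x, Point const& y) {
    return kernel(x, x) - 2.0 * kernel(x, y) + kernel(y, y);
}
```

With the linear kernel this reduces to the ordinary squared Euclidean distance; other kernels give a different geometry while the search code stays the same.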

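The BinaryTree class

This class is not the node of a binary tree in the usual sense; it represents a node of a binary space-partitioning tree. Each parent node splits the current space into two subspaces. The split may be linear, or it may be based on a kernel function. The class is defined in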

template <class InputT>
class BinaryTree
{
public:
    typedef InputT value_type;

    BinaryTree(std::size_t size)
    : mep_parent(NULL)
    , mp_left(NULL)
    , mp_right(NULL)
    , mp_indexList(NULL)
    , m_size(size)
    , m_nodes(0)
    , m_threshold(0.0)
    {
        SHARK_ASSERT(m_size > 0);

        mp_indexList = new std::size_t[m_size];
        boost::iota(boost::make_iterator_range(mp_indexList,mp_indexList+m_size),0);
    }

    virtual ~BinaryTree()
    {
        if (mp_left != NULL) delete mp_left;
        if (mp_right != NULL) delete mp_right;
        if (mep_parent == NULL) delete [] mp_indexList;
    }

    BinaryTree* parent()
    { return mep_parent; }

    const BinaryTree* parent() const
    { return mep_parent; }

    bool hasChildren() const
    { return (mp_left != NULL); }

    bool isLeaf() const
    { return (mp_left == NULL); }

    BinaryTree* left()
    { return mp_left; }

    const BinaryTree* left() const
    { return mp_left; }

    BinaryTree* right()
    { return mp_right; }

    const BinaryTree* right() const
    { return mp_right; }

    std::size_t size() const
    { return m_size; }

    std::size_t nodes() const
    { return m_nodes; }

    std::size_t index(std::size_t point)const{
        return mp_indexList[point];
    }

    double distanceFromPlane(value_type const& point) const{
        return funct(point) - m_threshold;
    }

    double threshold() const{
        return m_threshold;
    }

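    // Note: left() above returns the left child node, whereas this function
    // tests whether the query point lies in the left subspace
    bool isLeft(value_type const& point) const
    { return (funct(point) < m_threshold); }

    bool isRight(value_type const& point) const
    { return (funct(point) >= m_threshold); }

    // If distances are computed with a kernel, return the kernel function object
    virtual AbstractKernelFunction<value_type> const* kernel()const{
        //default is no kernel metric
        return NULL;
    }

    // Squared lower bound on the distance between the query point and this node's subspace.
    // Clever use of the triangle inequality tightens the bound and speeds up the search
    virtual double squaredDistanceLowerBound(value_type const& point) const = 0;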

protected:
    BinaryTree(BinaryTree* parent, std::size_t* list, std::size_t size)
    : mep_parent(parent)
    , mp_left(NULL)
    , mp_right(NULL)
    , mp_indexList(list)
    , m_size(size)
    , m_nodes(0)
    {}

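    // Value of the splitting function at the query point; subtracting
    // m_threshold gives the signed distance from the separating plane
    virtual double funct(value_type const& point) const = 0;

    // Split the data in this node and return the split point.
    // Range1 holds the scalar values to split on, Range2 the corresponding data points
    template<class Range1, class Range2>
    typename boost::range_iterator<Range2>::type splitList (Range1& values, Range2& points){
        typedef typename boost::range_iterator<Range1>::type iterator1;
        typedef typename boost::range_iterator<Range2>::type iterator2;

        iterator1 valuesBegin = boost::begin(values);
        iterator1 valuesEnd = boost::end(values);

        // partitionEqually splits the whole range into two parts of sizes as equal as possible
        std::pair<iterator1, iterator2> splitpoint = partitionEqually(zipKeyValuePairs(values,points)).iterators();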
        iterator1 valuesSplitpoint = splitpoint.first;
        iterator2 pointsSplitpoint = splitpoint.second;
        if (valuesSplitpoint == valuesEnd) {
            // partitioning failed, all values are equal :(
            m_threshold = *valuesBegin;
            return splitpoint.second;
        }

        // We don't want the threshold to be the value of an element but always in between two of them.
        // This ensures that no point of the training set lies on the boundary. This leads to more stable
        // results. So we use the mean of the found splitpoint and the nearest point on the other side
        // of the boundary.
        double maximum = *std::max_element(valuesBegin, valuesSplitpoint);
        m_threshold = 0.5*(maximum + *valuesSplitpoint);

        return pointsSplitpoint;
    }

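    // pointer to the parent node
    BinaryTree* mep_parent;

    // pointer to the left child
    BinaryTree* mp_left;

    // pointer to the right child
    BinaryTree* mp_right;

    // list of the indices of the data points in this node
    std::size_t* mp_indexList;

    // number of data points in this node
    std::size_t m_size;

    // number of nodes in the subtree rooted at this node
    std::size_t m_nodes;

    // threshold that splits the space
    double m_threshold;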

};
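The threshold rule in splitList can be illustrated in isolation: split the values into two halves of nearly equal size and put the threshold halfway between the largest value on the left and the smallest value on the right, so that no training point lies on the boundary. The sketch below is a simplified stand-in that uses sorting instead of Shark's partitionEqually; the function name `medianThreshold` and its tie handling are assumptions made for this example.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Precondition: `values` is not empty.
// Returns (true, threshold) on success, or (false, values[0]) when all
// values are equal and no threshold can separate them.
std::pair<bool, double> medianThreshold(std::vector<double> values) {
    std::sort(values.begin(), values.end());
    std::size_t n = values.size();
    if (values.front() == values.back())
        return std::make_pair(false, values.front());

    // Start in the middle and move the split out of any run of equal values,
    // first upwards, then downwards if the run extends to the end.
    std::size_t mid = n / 2;
    std::size_t split = mid;
    while (split < n && values[split] == values[split - 1]) ++split;
    if (split == n) {
        split = mid;
        // Terminates because not all values are equal.
        while (values[split] == values[split - 1]) --split;
    }
    // The threshold lies strictly between the two neighboring values.
    return std::make_pair(true, 0.5 * (values[split - 1] + values[split]));
}
```

For the values 1, 2, 3, 4 the split falls between 2 and 3, giving a threshold of 2.5.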

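The TreeConstruction class

This class describes the stopping criteria for tree construction: either the depth of the tree, or the maximum number of data points a leaf may contain. It is defined in the same place as BinaryTree.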

class TreeConstruction
{
public:
    TreeConstruction()
    : m_maxDepth(0xffffffff)
    , m_maxBucketSize(1)
    { }

    TreeConstruction(TreeConstruction const& other)
    : m_maxDepth(other.m_maxDepth)
    , m_maxBucketSize(other.m_maxBucketSize)
    { }

    TreeConstruction(unsigned int maxDepth, unsigned int maxBucketSize)
    : m_maxDepth(maxDepth ? maxDepth : 0xffffffff)
    , m_maxBucketSize(maxBucketSize ? maxBucketSize : 1)
    { }

    // decrease the remaining depth limit by one
    TreeConstruction nextDepthLevel() const
    { return TreeConstruction(m_maxDepth - 1, m_maxBucketSize); }

    unsigned int maxDepth() const
    { return m_maxDepth; }

    unsigned int maxBucketSize() const
    { return m_maxBucketSize; }

protected:
    // maximum depth of the tree
    unsigned int m_maxDepth;

    // maximum number of data points in a leaf node
    unsigned int m_maxBucketSize;
};

The KDTree class

This class is defined in .

template <class InputT>
class KDTree : public BinaryTree<InputT>
{
    typedef KDTree<InputT> self_type;
    typedef BinaryTree<InputT> base_type;
public:

    KDTree(Data<InputT> const& dataset, TreeConstruction tc = TreeConstruction())
    : base_type(dataset.numberOfElements())
    , m_cutDim(0xffffffff){
        typedef DataView<Data<InputT> const> PointSet;
        PointSet points(dataset);

        std::vector<typename boost::range_iterator<PointSet>::type> elements(m_size);
        boost::iota(elements,boost::begin(points));

        buildTree(tc,elements);

        // record the indices of the data points; this is done only in the root node
        for(std::size_t i = 0; i != m_size; ++i){
            mp_indexList[i] = elements[i].index();
        }
    }

    // Lower boundary of this node's subspace along dimension dim
    double lower(std::size_t dim) const{
        self_type* parent = static_cast<self_type*>(mep_parent);
        if (parent == NULL) return -1e100; // this is the root node

        // dim is the parent's cutting dimension and this node is the parent's right child
        if (parent->m_cutDim == dim && parent->mp_right == this)
            return parent->threshold();
        else
            return parent->lower(dim);
    }

    // Upper boundary of this node's subspace along dimension dim
    double upper(std::size_t dim) const{
        self_type* parent = static_cast<self_type*>(mep_parent);
        if (parent == NULL) return +1e100;

        if (parent->m_cutDim == dim && parent->mp_left == this) 
            return parent->threshold();
        else 
            return parent->upper(dim);
    }

    // Lower bound on the squared distance from the query point to this node's subspace
    double squaredDistanceLowerBound(InputT const& reference) const
    {
        double ret = 0.0;
        for (std::size_t d = 0; d != reference.size(); d++)
        {
            double v = reference(d);
            double l = lower(d);
            double u = upper(d);
            if (v < l){
                ret += sqr(l-v);
            }
            else if (v > u){
                ret += sqr(v-u);
            }
        }
        return ret;
    }

protected:
    using base_type::mep_parent;
    using base_type::mp_left;
    using base_type::mp_right;
    using base_type::mp_indexList;
    using base_type::m_size;
    using base_type::m_nodes;

    // constructor for non-root nodes
    KDTree(KDTree* parent, std::size_t* list, std::size_t size)
    : base_type(parent, list, size)
    , m_cutDim(0xffffffff)
    { }

    template<class Range>
    void buildTree(TreeConstruction tc, Range& points){
        typedef typename boost::range_iterator<Range>::type iterator;

        iterator begin = boost::begin(points);
        iterator end = boost::end(points);

        // stopping criterion reached: make this node a leaf
        if (tc.maxDepth() == 0 || m_size <= tc.maxBucketSize()){
            m_nodes = 1; 
            return; 
        }

        m_cutDim = calculateCuttingDimension(points);

        // collect the values of all points along the cutting dimension
        std::vector<double> distance(m_size);
        iterator point = begin;
        for(std::size_t i = 0; i != m_size; ++i,++point){
            distance[i] = get(**point,m_cutDim);
        }

        // split the data in this node
        iterator split = this->splitList(distance,points);
        if (split == end){
            // the split failed, so this node becomes a leaf
            m_nodes = 1;
            return; 
        }
        std::size_t leftSize = split-begin;

        // build the left and right children of this node
        mp_left = new KDTree(this, mp_indexList, leftSize);
        mp_right = new KDTree(this, mp_indexList + leftSize, m_size - leftSize);

        boost::iterator_range<iterator> left(begin,split);
        boost::iterator_range<iterator> right(split,end);
        ((KDTree*)mp_left)->buildTree(tc.nextDepthLevel(), left);
        ((KDTree*)mp_right)->buildTree(tc.nextDepthLevel(), right);
        m_nodes = 1 + mp_left->nodes() + mp_right->nodes();
    }

    // compute the cutting dimension for this node
    template<class Range>
    std::size_t calculateCuttingDimension(Range const& points)const{
        typedef typename boost::range_iterator<Range const>::type iterator;

        iterator begin = boost::begin(points);

        // compute the minimum and maximum of the current data along every dimension
        InputT L = **begin;
        InputT U = **begin;
        std::size_t dim = L.size();
        iterator point = begin;
        ++point;
        for (std::size_t i=1; i != m_size; ++i,++point){
            for (std::size_t d = 0; d != dim; d++){
                double v = (**point)[d];
                if (v < L[d]) L[d] = v;
                if (v > U[d]) U[d] = v;
            }
        }

        // choose the dimension with the largest extent as the cutting dimension
        std::size_t cutDim = 0;
        double extent = U[0] - L[0];
        for (std::size_t d = 1; d != dim; d++)
        {
            double e = U[d] - L[d];
            if (e > extent)
            {
                extent = e;
                cutDim = d;
            }
        }
        return cutDim;
    }

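    // Take the query point's value along the cutting dimension;
    // it decides whether the point lies in the left or the right subspace
    double funct(InputT const& reference) const{
        return reference[m_cutDim];
    }

    // dimension along which this node splits its subspace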
    std::size_t m_cutDim;
};
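Two ideas from KDTree can be shown without the tree itself: the lower bound on the squared distance from a query point to an axis-aligned cell, and the choice of the cutting dimension as the one with the largest extent. The following sketch assumes the cell is given directly by its lower and upper corners instead of being derived from the parent chain as in lower() and upper(); `boxSquaredDistance` and `widestDimension` are illustrative names.

```cpp
#include <cstddef>
#include <vector>

typedef std::vector<double> Point;  // hypothetical point type for this sketch

// Squared distance from `query` to the axis-aligned box [lower, upper].
// Dimensions in which the query lies inside the box contribute nothing.
double boxSquaredDistance(Point const& query, Point const& lower, Point const& upper) {
    double sum = 0.0;
    for (std::size_t d = 0; d != query.size(); ++d) {
        if (query[d] < lower[d]) {
            double diff = lower[d] - query[d];
            sum += diff * diff;
        } else if (query[d] > upper[d]) {
            double diff = query[d] - upper[d];
            sum += diff * diff;
        }
    }
    return sum;
}

// Index of the dimension with the largest spread (max - min) over all points.
// Precondition: `points` is not empty and all points have the same dimension.
std::size_t widestDimension(std::vector<Point> const& points) {
    Point lo = points[0];
    Point hi = points[0];
    for (std::size_t i = 1; i != points.size(); ++i) {
        for (std::size_t d = 0; d != lo.size(); ++d) {
            if (points[i][d] < lo[d]) lo[d] = points[i][d];
            if (points[i][d] > hi[d]) hi[d] = points[i][d];
        }
    }
    std::size_t best = 0;
    for (std::size_t d = 1; d != lo.size(); ++d) {
        if (hi[d] - lo[d] > hi[best] - lo[best]) best = d;
    }
    return best;
}
```

No point inside the cell can be closer to the query than this bound, which is what allows the search to skip whole subtrees.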

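The AbstractNearestNeighbors class

This class is the base class of all nearest-neighbor algorithms. It is defined in

template<class InputType,class LabelType>
class AbstractNearestNeighbors{
public:
    // The key is the distance between the neighbor and the query point, the value is the neighbor's class label
    typedef KeyValuePair<double,LabelType> DistancePair;
    typedef typename Batch<InputType>::type BatchInputType;

    // Return the k nearest neighbors of each input as a vector
    virtual std::vector<DistancePair> getNeighbors(BatchInputType const& batch, std::size_t k) const = 0;

    virtual LabeledData<InputType,LabelType>const& dataset()const = 0;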

    virtual ~AbstractNearestNeighbors() {}
};

The TreeNearestNeighbors class

This is the algorithm class for kNN; it uses the kd-tree to find the k nearest neighbors of the input data. It is defined in

template<class InputType, class LabelType>
class TreeNearestNeighbors:public AbstractNearestNeighbors<InputType,LabelType>
{
private:
    typedef AbstractNearestNeighbors<InputType,LabelType> base_type;

public:
    typedef LabeledData<InputType,LabelType> Dataset;
    typedef BinaryTree<InputType> Tree;
    typedef typename base_type::DistancePair DistancePair;
    typedef typename Batch<InputType>::type BatchInputType;

    TreeNearestNeighbors(Dataset const& dataset, Tree const* tree)
    : m_dataset(dataset), m_inputs(dataset.inputs()), m_labels(dataset.labels()),mep_tree(tree)
    { }

    std::vector<DistancePair> getNeighbors(BatchInputType const& patterns, std::size_t k)const{
        std::size_t numPoints = shark::size(patterns);
        std::vector<DistancePair> results(k*numPoints);
        for(std::size_t p = 0; p != numPoints; ++p){
            IterativeNNQuery<DataView<Data<InputType> const> > query(mep_tree, m_inputs, get(patterns, p));
            // find the k nearest neighbors of the input and store the results
            for(std::size_t i = 0; i != k; ++i){
                typename IterativeNNQuery<DataView<Data<InputType> const> >::result_type result = query.next();
                results[i+p*k].key=result.first;
                results[i+p*k].value= m_labels[result.second]; 
            }
        }
        return results;
    }

    LabeledData<InputType,LabelType>const& dataset()const {
        return m_dataset;
    }

private:
    Dataset const& m_dataset; // training data set; it does not seem to serve much purpose here
    DataView<Data<InputType> const> m_inputs;
    DataView<Data<LabelType> const> m_labels;
    Tree const* mep_tree; // the kd-tree that has already been built

};

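The IterativeNNQuery class

As the TreeNearestNeighbors code shows, querying the kd-tree is done mainly by the IterativeNNQuery class. It allows the neighbors of a given point to be queried iteratively: the first result is the nearest neighbor, the second is the second nearest, and so on. When the kd-tree is built, the bucket size of the leaves must be set to 1; a leaf may still contain several identical data points. The training data should be stored in a container with random access, which speeds up the query. The class is defined in the same place as TreeNearestNeighbors.

template <class DataContainer>
class IterativeNNQuery
{
public:
    typedef typename DataContainer::value_type value_type;
    typedef BinaryTree<value_type> tree_type;
    typedef AbstractKernelFunction<value_type> kernel_type;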
    typedef std::pair<double, std::size_t> result_type;

    IterativeNNQuery(tree_type const* tree, DataContainer const& data, value_type const& point)
    : m_data(data)
    , m_reference(point)
    , m_nextIndex(0)
    , mp_trace(NULL)
    , mep_head(NULL)
    , m_squaredRadius(0.0)
    , m_neighbors(0)
    {
        mp_trace = new TraceNode(tree, NULL, m_reference);
        TraceNode* tn = mp_trace;
        // extend the trace down to a leaf of the kd-tree
        while (tree->hasChildren())
        {
            tn->createLeftNode(tree, m_data, m_reference);
            tn->createRightNode(tree, m_data, m_reference);
            bool left = tree->isLeft(m_reference);
            tn = (left ? tn->mep_left : tn->mep_right);
            tree = (left ? tree->left() : tree->right());
        }
        mep_head = tn->mep_parent;
        insertIntoQueue((TraceLeaf*)tn); // first add the leaf that contains the query point to the queue
        m_squaredRadius = mp_trace->squaredRadius(m_reference);
    }

    ~IterativeNNQuery() {
        m_queue.clear();
        delete mp_trace;
    }

    std::size_t neighbors() const {
        return m_neighbors;
    }

    /// find and return the next nearest neighbor
    result_type next() {
        if (m_neighbors >= mp_trace->m_tree->size()) 
            throw SHARKEXCEPTION("[IterativeNNQuery::next] no more neighbors available");

        assert(! m_queue.empty());

        if (m_neighbors > 0){
            TraceLeaf& q = *m_queue.begin();
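            // The current leaf still holds points that have not been returned yet
            if (m_nextIndex < q.m_tree->size()){
                // m_neighbors arguably should be incremented here, but as the later code
                // shows, the caller's loop keeps its own counter as well
                return getNextPoint(q);
            }
            else
                m_queue.erase(q);
        }
        // Refill the candidate queue when it is empty
        // or when its closest candidate lies outside the covered radius
        if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
            // Walk back up the trace, looking for nodes that have not been searched yet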
            TraceNode* tn = mep_head;
            while (tn != NULL){
                enqueue(tn);
                if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
                tn = tn->mep_parent;
            }

            m_squaredRadius = mp_trace->squaredRadius(m_reference);
        }
        m_nextIndex = 0;
        ++m_neighbors;
        return getNextPoint(*m_queue.begin());
    }

    std::size_t queuesize() const{ 
        return m_queue.size();
    }

private:
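    // Status of a trace-tree node during the search
    enum Status
    {
        NONE,            // none of the node's points have been added to the candidate queue
        PARTIAL,         // some of them have been added
        COMPLETE,        // all points have been searched
    };

    // The trace tree is built during the search and only needs to cover the nodes the search may visit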
    class TraceNode
    {
    public:
        TraceNode(tree_type const* tree, TraceNode* parent, value_type const& reference)
        : m_tree(tree)
        , m_status(NONE)
        , mep_parent(parent)
        , mep_left(NULL)
        , mep_right(NULL)
        , m_squaredDistance(tree->squaredDistanceLowerBound(reference))
        { }

        virtual ~TraceNode()
        {
            if (mep_left != NULL) delete mep_left;
            if (mep_right != NULL) delete mep_right;
        }

        void createLeftNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
            if (tree->left()->hasChildren())
                mep_left = new TraceNode(tree->left(), this, reference);
            else
                // if the left child is a leaf, create a TraceLeaf instead
                mep_left = new TraceLeaf(tree->left(), this, data, reference);
        }
        void createRightNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
            if (tree->right()->hasChildren())
                mep_right = new TraceNode(tree->right(), this, reference);
            else
                mep_right = new TraceLeaf(tree->right(), this, data, reference);
        }

        /// Compute the squared distance of the area not
        /// yet covered by the queue to the reference point.
        /// This is also referred to as the squared "radius"
        /// of the area covered by the queue (in fact, it is
        /// the radius of the largest sphere around the
        /// reference point that fits into the covered area).
        double squaredRadius(value_type const& ref) const{
            if (m_status == NONE) return m_squaredDistance;
            else if (m_status == PARTIAL)
            {
                double l = mep_left->squaredRadius(ref);
                double r = mep_right->squaredRadius(ref);
                return std::min(l, r);
            }
            else return 1e100;
        }

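        // kd-tree node that corresponds to this position of the search
        tree_type const* m_tree;

        // status of the node
        Status m_status;

        // parent node in the trace tree
        TraceNode* mep_parent;

        // left child in the trace tree
        TraceNode* mep_left;

        // right child in the trace tree
        TraceNode* mep_right;

        // squared lower bound on the distance from this subspace to the query point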
        double m_squaredDistance;
    };

    /// hook type for intrusive container
    typedef boost::intrusive::set_base_hook<> HookType;

    class TraceLeaf : public TraceNode, public HookType
    {
    public:

        TraceLeaf(tree_type const* tree, TraceNode* parent, DataContainer const& data, value_type const& ref)
        : TraceNode(tree, parent, ref){
            // check whether the tree uses a kernel distance, then compute the distance from the leaf's data point to the query point
            if(tree->kernel() != NULL)
                m_squaredPtDistance = tree->kernel()->featureDistanceSqr(data[tree->index(0)], ref);
            else
                m_squaredPtDistance = distanceSqr(data[tree->index(0)], ref);
        }

        ~TraceLeaf() { }

        inline bool operator < (TraceLeaf const& rhs) const{
            if (m_squaredPtDistance == rhs.m_squaredPtDistance) 
                return (this->m_tree < rhs.m_tree);
            else
                return (m_squaredPtDistance < rhs.m_squaredPtDistance);
        }

        double m_squaredPtDistance;
    };

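    // Add a leaf to the current search queue
    void insertIntoQueue(TraceLeaf* leaf){
        // the candidate queue is a red-black tree, so it is reordered on insertion
        m_queue.insert_unique(*leaf);

        // walk up the trace tree and update the node statuses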
        TraceNode* tn = leaf;
        tn->m_status = COMPLETE;
        while (true){
            TraceNode* par = tn->mep_parent;
            if (par == NULL) break;
            if (par->m_status == NONE){
                par->m_status = PARTIAL;
                break;
            }
            else if (par->m_status == PARTIAL){
                // if both children are complete, mark the parent as complete too
                if (par->mep_left == tn){
                    if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
                    else break;
                }
                else{
                    if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
                    else break;
                }
            }
            tn = par;
        }
    }

    // Pack the distance and the point index into a pair and return it
    result_type getNextPoint(TraceLeaf const& leaf){
        double dist = std::sqrt(leaf.m_squaredPtDistance);
        std::size_t index = leaf.m_tree->index(m_nextIndex);
        ++m_nextIndex;
        return std::make_pair(dist,index);
    }

    /// Recursively descend the node and enqueue
    /// all points in cells intersecting the
    /// current bounding sphere.
    void enqueue(TraceNode* tn){
        // Return if the subtree rooted at this node has already been searched
        if (tn->m_status == COMPLETE) return;
        if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;

        const tree_type* tree = tn->m_tree;
        // If the node has branches that have not been visited yet, extend the trace tree
        if (tree->hasChildren()){
            if (tn->mep_left == NULL){
                tn->createLeftNode(tree,m_data,m_reference);
            }
            if (tn->mep_right == NULL){
                tn->createRightNode(tree,m_data,m_reference);
            }

            // If the query point lies in the left subspace, search the left subspace first
            if (tree->isLeft(m_reference))
            {
                enqueue(tn->mep_left);
                enqueue(tn->mep_right);
            }
            else
            {
                enqueue(tn->mep_right);
                enqueue(tn->mep_left);
            }
        }
        else
        {
            TraceLeaf* leaf = (TraceLeaf*)tn;
            insertIntoQueue(leaf);
        }
    }

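    // queue of candidate leaves
    typedef boost::intrusive::rbtree<TraceLeaf> QueueType;

    // training data
    DataContainer const& m_data;

    // the query point
    value_type m_reference;

    QueueType m_queue;

    // index of the next point in the current leaf that has not been returned yet
    std::size_t m_nextIndex;

    // A trace tree is built while the kd-tree is searched;
    // this is the root of the trace tree
    TraceNode* mp_trace;

    // trace-tree node the search has currently reached
    TraceNode* mep_head;

    // squared radius of the region covered so far
    double m_squaredRadius;

    // number of neighbors found so far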
    std::size_t m_neighbors;
};
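The interface of IterativeNNQuery, where each call to next() returns the next nearest neighbor as a (distance, index) pair, can be mimicked by a much simpler brute-force stand-in. The sketch below does not use the kd-tree or the trace tree at all; it only puts all squared distances into a min-heap and pops them one at a time. `LazyNeighborQuery` and `hasNext` are hypothetical names chosen for this example.

```cpp
#include <cmath>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

typedef std::vector<double> Point;  // hypothetical point type for this sketch

// Brute-force stand-in: yields neighbors from nearest to farthest.
class LazyNeighborQuery {
public:
    // (distance, index of the point in the data set)
    typedef std::pair<double, std::size_t> result_type;

    LazyNeighborQuery(std::vector<Point> const& data, Point const& query) {
        for (std::size_t i = 0; i != data.size(); ++i) {
            double sum = 0.0;
            for (std::size_t d = 0; d != query.size(); ++d) {
                double diff = data[i][d] - query[d];
                sum += diff * diff;
            }
            m_heap.push(result_type(sum, i));  // the heap stores squared distances
        }
    }

    bool hasNext() const { return !m_heap.empty(); }

    // Precondition: hasNext(). Returns the next nearest neighbor.
    result_type next() {
        result_type top = m_heap.top();
        m_heap.pop();
        return result_type(std::sqrt(top.first), top.second);
    }

private:
    // min-heap ordered by squared distance, ties broken by index
    std::priority_queue<result_type, std::vector<result_type>,
                        std::greater<result_type> > m_heap;
};
```

The real class returns the same sequence but avoids computing most distances, because it only enqueues leaves whose cells can still contain a closer point.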

The NearestNeighborClassifier class

This class classifies the input data using the k nearest neighbors that were found. It is defined in

template <class InputType>
class NearestNeighborClassifier : public AbstractModel<InputType, unsigned int>
{
public:
    typedef AbstractNearestNeighbors<InputType,unsigned int> NearestNeighbors;
    typedef AbstractModel<InputType, unsigned int> base_type;
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::BatchOutputType BatchOutputType;

    // Weight given to each neighbor, based on its distance, when the class label is decided
    enum DistanceWeights
    {
        UNIFORM,                // no weighting, plain majority vote
        ONE_OVER_DISTANCE,      // weight each neighbor by the inverse of its distance
    };
    };

    NearestNeighborClassifier(NearestNeighbors const* algorithm, std::size_t neighbors = 3)
    : m_algorithm(algorithm)
    , m_classes(numberOfClasses(algorithm->dataset()))
    , m_neighbors(neighbors)
    , m_distanceWeights(UNIFORM)
    { }

    std::string name() const
    { return "NearestNeighborClassifier"; }

    std::size_t neighbors() const{
        return m_neighbors;
    }

    void setNeighbors(std::size_t neighbors){
        m_neighbors=neighbors;
    }

    DistanceWeights getDistanceWeightType() const
    { return m_distanceWeights; }

    void setDistanceWeightType(DistanceWeights dw)
    { m_distanceWeights = dw; }

    virtual RealVector parameterVector() const{
        RealVector parameters(1);
        parameters(0) = (double)m_neighbors;
        return parameters;
    }

    virtual void setParameterVector(RealVector const& newParameters){
        SHARK_CHECK(newParameters.size() == 1,
            "[SoftNearestNeighborClassifier::setParameterVector] invalid number of parameters");
        //~ SHARK_CHECK((std::size_t)newParameters(0) == newParameters(0) && newParameters(0) >= 1.0,
            //~ "[SoftNearestNeighborClassifier::setParameterVector] invalid number of neighbors");
        m_neighbors = (std::size_t)newParameters(0);
    }

    virtual std::size_t numberOfParameters() const{
        return 1;
    }

    boost::shared_ptr<State> createState()const{
        return boost::shared_ptr<State>(new EmptyState());
    }

    using base_type::eval;

    void eval(BatchInputType const& patterns, BatchOutputType& output, State& state)const{
        std::size_t numPatterns = shark::size(patterns);
        // fetch the k nearest neighbors
        std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns,m_neighbors);

        output.resize(numPatterns);
        output.clear();

        for(std::size_t p = 0; p != numPatterns;++p){
            std::vector<double> histogram(m_classes, 0.0);
            for ( std::size_t k = 0; k != m_neighbors; ++k){
                // accumulate the votes per class; how they are counted depends on the chosen weighting
                if (m_distanceWeights == UNIFORM) histogram[neighbors[p*m_neighbors+k].value]++;
                else
                {
                    double d = neighbors[p*m_neighbors+k].key;
                    if (d < 1e-100) histogram[neighbors[p*m_neighbors+k].value] += 1e100;
                    else histogram[neighbors[p*m_neighbors+k].value] += 1.0 / d;
                }
            }
            output(p) = static_cast<unsigned int>(std::max_element(histogram.begin(),histogram.end()) - histogram.begin());
        }
    }

    void read(InArchive& archive){
        archive & m_neighbors;
        archive & m_classes;
    }

    void write(OutArchive& archive) const{
        archive & m_neighbors;
        archive & m_classes;
    }

protected:
    NearestNeighbors const* m_algorithm;

    std::size_t m_classes;

    // the k in kNN
    std::size_t m_neighbors;

    DistanceWeights m_distanceWeights;
};
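The voting step of eval can be isolated into a small function. This sketch follows the logic of the listing above (a histogram over classes, optionally weighted by the inverse distance, with the same guard for near-zero distances), but the `Neighbor` struct and the function `voteLabel` are names made up for this illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// One neighbor: its distance to the query point and its class label.
struct Neighbor {
    double distance;
    unsigned int label;
};

// Returns the winning class label among the given neighbors.
// Precondition: numClasses > 0 and every label is smaller than numClasses.
unsigned int voteLabel(std::vector<Neighbor> const& neighbors,
                       std::size_t numClasses, bool weightByDistance)
{
    std::vector<double> histogram(numClasses, 0.0);
    for (std::size_t i = 0; i != neighbors.size(); ++i) {
        double weight = 1.0;  // uniform weighting: plain majority vote
        if (weightByDistance) {
            // inverse-distance weighting, guarded against division by zero
            weight = (neighbors[i].distance < 1e-100) ? 1e100
                                                      : 1.0 / neighbors[i].distance;
        }
        histogram[neighbors[i].label] += weight;
    }
    // On ties, max_element picks the smallest class label.
    return static_cast<unsigned int>(
        std::max_element(histogram.begin(), histogram.end()) - histogram.begin());
}
```

With one very close neighbor of class 0 and two more distant neighbors of class 1, the uniform vote picks class 1, while inverse-distance weighting picks class 0.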

A concrete example

As before, here is a concrete example.

#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace shark;
using namespace std;

int main(int argc, char **argv) {
    if(argc < 2) {
        cerr << "usage: " << argv[0] << " (filename)" << endl;
        exit(EXIT_FAILURE);
    }
    // read data
    ClassificationDataset data;
    try {
        importCSV(data, argv[1], LAST_COLUMN, ' ');
    } 
    catch (...) {
        cerr << "unable to read data from file " <<  argv[1] << endl;
        exit(EXIT_FAILURE);
    }

    cout << "number of data points: " << data.numberOfElements()
         << " number of classes: " << numberOfClasses(data)
         << " input dimension: " << inputDimension(data) << endl;

    // split data into training and test set
    ClassificationDataset dataTest = splitAtElement(data, static_cast<std::size_t>(.5 * data.numberOfElements()));
    cout << "training data points: " << data.numberOfElements() << endl;
    cout << "test data points: " << dataTest.numberOfElements() << endl;

    //create a binary search tree and initialize the search algorithm - a fast tree search
    KDTree<RealVector> tree(data.inputs());
    TreeNearestNeighbors<RealVector,unsigned int> algorithm(data,&tree);
    //instantiate the classifier
    const unsigned int K = 1; // number of neighbors for kNN
    NearestNeighborClassifier<RealVector> KNN(&algorithm,K);

    // evaluate classifier
    ZeroOneLoss<unsigned int> loss;
    Data<unsigned int> prediction = KNN(data.inputs());
    cout << K << "-KNN on training set accuracy: " << 1. - loss.eval(data.labels(), prediction) << endl;
    prediction = KNN(dataTest.inputs());
    cout << K << "-KNN on test set accuracy:     " << 1. - loss.eval(dataTest.labels(), prediction) << endl;
}

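First a kd-tree is built from the input data; then an object of the algorithm class TreeNearestNeighbors is created and given the tree; finally the KNN classifier NearestNeighborClassifier is constructed and used to classify the test data.

As mentioned at the beginning, a tree based on a kernel distance can be used as well. The code then takes the following form: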

LinearKernel<RealVector> kernel;
KHCTree<RealVector> tree(data.inputs(), &kernel);
TreeNearestNeighbors<RealVector, unsigned int> algorithm(data, &tree);
NearestNeighborClassifier<RealVector> KNN(&algorithm, K);

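Note that only the tree class has changed.

Without a tree structure for the search, a distance measure still has to be defined. The algorithm class is then no longer TreeNearestNeighbors but SimpleNearestNeighbors. The corresponding code is: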

LinearKernel<> kernel;
SimpleNearestNeighbors<RealVector, unsigned int> algorithm(data, &kernel);
NearestNeighborClassifier<RealVector> KNN(&algorithm, K);

If you want the membership of the data in each class rather than a single class label, you can use the SoftNearestNeighborClassifier classifier class.
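One plausible way to turn the k neighbors into class memberships is to report the fraction of neighbors that belong to each class. The sketch below is an assumption about what such an output could look like, not a copy of SoftNearestNeighborClassifier; `classMemberships` is a hypothetical function name.

```cpp
#include <cstddef>
#include <vector>

// Returns, for each class, the fraction of neighbors carrying that label.
// Precondition: every label is smaller than numClasses.
std::vector<double> classMemberships(std::vector<unsigned int> const& neighborLabels,
                                     std::size_t numClasses)
{
    std::vector<double> membership(numClasses, 0.0);
    if (neighborLabels.empty()) return membership;  // no neighbors: all zeros

    // Count the neighbors per class.
    for (std::size_t i = 0; i != neighborLabels.size(); ++i)
        membership[neighborLabels[i]] += 1.0;

    // Normalize the counts so that the memberships sum to one.
    for (std::size_t c = 0; c != numClasses; ++c)
        membership[c] /= static_cast<double>(neighborLabels.size());
    return membership;
}
```

Taking the class with the largest membership gives back the uniform majority vote of NearestNeighborClassifier.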
