k-d树学习

题目:http://acm.hdu.edu.cn/showproblem.php?pid=4347

k-d树在acm界好像不是很常见的样子,至于到底会不会考到我也不清楚,我遇到的题目有两个,第一个是今年长春邀请赛的时候的D题,题解是这么说的

D:(Fire Station Problem),7,本意为KD tree 过,但是由于坐标比较小实际上部分队伍水过
没找到当时的网址,抱歉。

还有一个题目我遇到的就是今年的多小第5场,也就是开始给的题目连接。

学习资料:http://en.wikipedia.org/wiki/Kd-tree#Construction  只有这个了,然后我在网上搜了一篇题解看的代码。

题目大意:给出N多个点,然后给出T次询问,每一次给出一个坐标,问N多个点中离这个点最近的k个点坐标是??

这个题目应该是k-d 树的模板题吧

再加一个资料:来源于大牛 木子日匀 http://www.mzry1992.com/blog/miao/kd%E6%A0%91.html 大牛写的很好!!

k-d树还有几个操作,第一,插入点,第二,删除点,第三就是这个题的查询最近点(又叫KNN——Thenearest neighbour search)


kd树是循环依次对每一维进行二分的一颗二叉树!


贴一下我的代码,如果你不想看wiki上的英文,那你就只能看代码了!当然,如何建树你是必须看的!!

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <set>
using namespace std;
const int MAXN = 50010;
const int MAXK = 5;

int n, m, k;
struct Point
{
    int p[MAXK];

    inline void input()
    {
        for(int i=0;i<k;++i)
            scanf("%d", &p[i]);
    }
    inline void output() const
    {
        for(int i=0;i<k;++i)
        {
            if(i) printf(" ");
            printf("%d", p[i]);
        }
        printf("\n");
    }
}point[MAXN], searchPoint;

int pointSet[MAXN];
set<Point> ans;

inline int getDistance(const Point &a, const Point &b)
{
    int ans = 0;
    for(int i=0;i<k;++i)
    {
        ans += (a.p[i] - b.p[i]) * (a.p[i] - b.p[i]);
    }
    return ans;
}

bool operator < (const Point &a, const Point &b)
{
    return getDistance(a, searchPoint) < getDistance(b, searchPoint);
}

struct TreeNode
{
    int index;
    bool left, right;
}tree[MAXN*2];//注意这个地方时MAXN*2 我看的那个原作者的这个地方没*2肯定是不行的

int cmpArgs;
bool compare(int x, int y)
{
    return point[x].p[cmpArgs] < point[y].p[cmpArgs];
}
bool build(int l,int r,int rt,int dep)//建树我使用了类似于建线段树的那种建法
{
    int dim = dep%k;
    if(r < l) return 0;
    if(l == r)
    {
        tree[rt].index = pointSet[l];
        tree[rt].left = tree[rt].right = 0;
        return 1;
    }
    cmpArgs = dim;
    sort(pointSet+l,pointSet + r + 1,compare);

    int m = (l+r)>>1;

    tree[rt].index = pointSet[m];
    tree[rt].left = build(l,m-1,rt<<1,dep+1);
    tree[rt].right = build(m+1,r,rt<<1|1,dep+1);
    return true;
}
#ifndef ONLINE_JUDGE
void printTree(int x, int depth)
{
    printf("Layer %d: ", depth);
    point[tree[x].index].output();
    if(tree[x].left)
    {
        printTree(x << 1, depth + 1);
    }
    if(tree[x].right)
    {
        printTree((x << 1) + 1, depth + 1);
    }
}
#endif

inline void insertPossibility(const Point &point)
{
    ans.insert(point);
    if(ans.size() > m)
    {
        set<Point>::reverse_iterator it = ans.rbegin();
        ans.erase(*it);
    }
}

void queryTree(int x,int dep)//总感觉查询的代码可以优化,但是没想出来怎么减少代码量
{
    insertPossibility(point[tree[x].index]);
    if(!tree[x].left && !tree[x].right)
    {
        return;
    }
    int dim = dep%k;
    bool flag = false;
    int dist1 = searchPoint.p[dim] - point[tree[x].index].p[dim];
    dist1 *= dist1;
    if(searchPoint.p[dim] < point[tree[x].index].p[dim])
    {
        if(tree[x].left)
            queryTree(x << 1,dep+1);
        if(tree[x].right)
        {
            if(ans.size() < m)
                flag = true;
            else
            {
                set<Point>::reverse_iterator it = ans.rbegin();
                Point temp = *it;
                int dist2 = getDistance(temp, searchPoint);
                if(dist1 <= dist2)
                    flag = true;
            }
            if(flag)
                queryTree((x << 1) + 1,dep+1);
        }
    }
    else
    {
        if(tree[x].right)
            queryTree((x << 1) + 1,dep+1);
        if(tree[x].left)
        {
            if(ans.size() < m)
                flag = true;
            else
            {
                set<Point>::reverse_iterator it = ans.rbegin();
                Point temp = *it;
                int dist2 = getDistance(temp, searchPoint);
                if(dist1 <= dist2)
                    flag = true;

            }
            if(flag)
                queryTree(x << 1,dep+1);
        }
    }
}

void solve()
{
    ans.clear();
    queryTree(1,0);
}

int main()
{
    int t;
    while(~scanf("%d%d",&n,&k))
    {
        for(int i=0;i<n;++i)
        {
            point[i].input();
            pointSet[i] = i;
        }
        build(0,n-1,1, 0);

        #ifndef ONLINE_JUDGE
            //printTree(1, 0);
        #endif

        scanf("%d", &t);
        while(t--)
        {
            searchPoint.input();
            scanf("%d", &m);
            solve();
            printf("the closest %d points are:\n", m);
            for(set<Point>::iterator it=ans.begin();it!=ans.end();++it)
            {
                it->output();
            }
        }
    }
    return 0;
}


你可能感兴趣的:(k-d树学习)