HDU 2966 KDtree模板

题意:

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2966
给定平面上的 n 个点,对每个点求它到最近的另一个点(不含自身)的欧氏距离的平方。


思路:

二维 KD-tree 模板题:按 x、y 两个坐标轴交替取中位数建树;对每个点做最近邻查询时跳过它自身,用"查询点到子树包围盒的最小距离"剪枝,并优先进入该距离更小的子树。


代码:

// KD-tree template for 2-D points (k = 2): build, insert, nearest-neighbour query
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long LL;

// "Infinity" for squared distances: the byte 0x3f repeated eight times (~4.56e18).
// NOTE(review): assumes every real squared distance stays below this value; that
// holds for coordinates in [0, 1e9] (max 2e18). Confirm against the problem limits.
const LL INF = 0x3f3f3f3f3f3f3f3fLL;

// Array capacity: 1e5 slots plus a little slack.
const int MAXN = 100000 + 10;

// One KD-tree node. Nodes live in the array kdt[] and refer to each other by
// index; index 0 means "no child".
struct Node {
    int lson;     // index of the left child in kdt[] (0 = none)
    int rson;     // index of the right child in kdt[] (0 = none)
    LL Min[2];    // per-dimension minimum over this node's subtree (bounding box)
    LL Max[2];    // per-dimension maximum over this node's subtree (bounding box)
    LL x[2];      // the point stored at this node
    int id;       // 1-based position of this point in the input order
};

Node kdt[MAXN << 1];
Node tmp;

int root;   // index in kdt[] of the tree root
int cmp_x;  // dimension (0 or 1) that cmp() compares first

LL ans;     // best squared distance found so far by the running query
LL xx0;     // query point, dimension 0
LL xx1;     // query point, dimension 1

// Strict weak ordering of nodes: by coordinate cmp_x first,
// then by the other coordinate to break ties.
bool cmp (const Node &a, const Node &b) {
    const int primary = cmp_x;
    const int secondary = cmp_x ^ 1;
    if (a.x[primary] != b.x[primary]) return a.x[primary] < b.x[primary];
    return a.x[secondary] < b.x[secondary];
}

// Grow node u's bounding box (Min/Max) so that it also covers node v's bounding box.
void pushUp(int u, int v) {
    for (int d = 0; d < 2; ++d) {
        if (kdt[v].Min[d] < kdt[u].Min[d]) kdt[u].Min[d] = kdt[v].Min[d];
        if (kdt[v].Max[d] > kdt[u].Max[d]) kdt[u].Max[d] = kdt[v].Max[d];
    }
}

int kdtBuild(int l, int r, int X) {
    int mid = (l + r) >> 1;
    kdt[mid].lson = kdt[mid].rson = 0;
    cmp_x = X;
    nth_element(kdt + l + 1, kdt + mid + 1, kdt + r + 1, cmp); // 将编号为mid的元素放在中间,比它小的放在前面,比它大的放后面
    kdt[mid].Min[0] = kdt[mid].Max[0] = kdt[mid].x[0];
    kdt[mid].Min[1] = kdt[mid].Max[1] = kdt[mid].x[1];
    if (l != mid) kdt[mid].lson = kdtBuild(l, mid - 1, X ^ 1);
    if (r != mid) kdt[mid].rson = kdtBuild(mid + 1, r, X ^ 1);
    if (kdt[mid].lson) pushUp(mid, kdt[mid].lson);
    if (kdt[mid].rson) pushUp(mid, kdt[mid].rson);
    return mid;
}

// 插入新的结点
void kdtInsert(int now) {
    int X = 0, p = root;
    while (true) {
        pushUp(p, now);
        if (kdt[now].x[X] < kdt[p].x[X]) {
            if (!kdt[p].lson) {
                kdt[p].lson = now;
                return;
            }
            else p = kdt[p].lson;
        }
        else {
            if (!kdt[p].rson) {
                kdt[p].rson = now;
                return;
            }
            else p = kdt[p].rson;
        }
    }
}

// Lower bound on the squared Euclidean distance from (x0, x1) to any point inside
// the bounding box (Min/Max) of node id's subtree. The result is 0 when the point
// lies inside the box. Despite the name, this is a MINIMUM possible distance; it
// is what the query uses to prune subtrees.
LL getMaxDis(int id, LL x0, LL x1) {
    const LL q[2] = {x0, x1};
    LL total = 0;
    for (int d = 0; d < 2; ++d) {
        const LL lo = kdt[id].Min[d];
        const LL hi = kdt[id].Max[d];
        if (q[d] < lo) total += (lo - q[d]) * (lo - q[d]);
        if (q[d] > hi) total += (hi - q[d]) * (hi - q[d]);
    }
    return total;
}

// Squared Euclidean distance between node id's point and (x0, x1).
LL dist(int id, LL x0, LL x1) {
    const LL dx = kdt[id].x[0] - x0;
    const LL dy = kdt[id].x[1] - x1;
    return dx * dx + dy * dy;
}

void kdtQuery(int p) {
    LL dl = INF, dr = INF, d;
    d = dist(p, xx0, xx1);
    if (kdt[p].x[0] == xx0 && kdt[p].x[1] == xx1) d = INF;  // 查询(x,y)时要将(x,y)到自己的距离设为INF
    ans = min(ans, d);
    if (kdt[p].lson) dl = getMaxDis(kdt[p].lson, xx0, xx1);
    if (kdt[p].rson) dr = getMaxDis(kdt[p].rson, xx0, xx1);
    if (dl < dr) {
        if (dl < ans) kdtQuery(kdt[p].lson);
        if (dr < ans) kdtQuery(kdt[p].rson);
    }
    else {
        if (dr < ans) kdtQuery(kdt[p].rson);
        if (dl < ans) kdtQuery(kdt[p].lson);
    }
}

LL answer[MAXN];

int main() {
    //freopen("in.txt", "r", stdin);
    int T;
    scanf("%d", &T);
    while (T--) {
        int n;
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) {
            scanf("%I64d%I64d", &kdt[i].x[0], &kdt[i].x[1]);
            kdt[i].id = i;
        }
        root = kdtBuild(1, n, 0);
        for (int i = 1; i <= n; i++) {
            ans = INF;
            xx0 = kdt[i].x[0]; xx1 = kdt[i].x[1];
            kdtQuery(root);
            answer[kdt[i].id] = ans;
           // printf("---%d\n", ans);
        }
        for (int i = 1; i <= n; i++)
            printf("%I64d\n", answer[i]);
    }
    return 0;
}

你可能感兴趣的:(KDtree)