SPOJ 913 Query on a tree II

SPOJ_913

    这个题目应该也可以用树链剖分去做，只不过在 KTH 这个操作上，感觉还是用 link-cut tree 更好写一些。

#include<stdio.h>
#include<string.h>
const int MAXD = 10010;   // capacity of every vertex-indexed array (vertices are 1-based)
const int MAXM = 20010;   // capacity of edge arrays; each tree edge is stored twice (both directions)
int N;                    // number of vertices in the current test case
int q[MAXD];              // BFS queue used while building the forest
int first[MAXD];          // head of each vertex's adjacency list, -1 when empty
int e;                    // number of directed edges stored so far
int next[MAXM];           // next edge in the same adjacency list, -1 terminates
int v[MAXM];              // target vertex of each directed edge
int w[MAXM];              // weight of each directed edge
// One node of the link-cut tree. Every tree vertex owns exactly one node;
// nodes are grouped into splay trees ordered by depth (left = shallower,
// right = deeper). Index 0 is used as an empty/sentinel child.
struct Splay
{
    int pre;    // splay parent, or the path-parent pointer when root is true
    int ls;     // left child (shallower part of the preferred path)
    int rs;     // right child (deeper part of the preferred path)
    int size;   // number of nodes in this splay subtree
    int sum;    // sum of key over this splay subtree
    int key;    // weight of the edge to the tree parent, as set in prepare()
    bool root;  // true if this node is the root of its splay tree
    void update();
    void zig(int );
    void zag(int );
    void splay(int );
    // Reset to a detached single-node splay tree (key/sum are left untouched).
    void renew()
    {
        size = 1;
        ls = 0;
        rs = 0;
        pre = 0;
        root = true;
    }
} sp[MAXD];
void Splay::update()
{
    size = sp[ls].size + sp[rs].size + 1;
    sum = sp[ls].sum + sp[rs].sum + key;
}
// Left rotation at this node. Callers pass the node's own index as x
// (e.g. sp[y].zig(y)). The right child y moves up and this node becomes y's
// left child. y takes over this node's parent pointer fa. If this node was
// the root of its splay tree, y also takes over the root flag, so fa (then a
// path-parent pointer) stays attached to the splay root.
// Only this node's aggregates are recomputed here; y's are not.
void Splay::zig(int x)
{
    int y = rs, fa = pre;
    // y's left subtree becomes this node's right subtree. When it is empty
    // (rs == 0) this writes sp[0].pre, which the visible code never reads.
    rs = sp[y].ls, sp[rs].pre = x;
    sp[y].ls = x, pre = y;
    sp[y].pre = fa;
    if(root)
        root = false, sp[y].root = true;
    else
        // Re-point the old parent's child link from x to y.
        sp[fa].rs == x ? sp[fa].rs = y : sp[fa].ls = y;
    update();
}
// Right rotation at this node, the mirror image of zig(). Callers pass the
// node's own index as x. The left child y moves up and this node becomes y's
// right child. y takes over this node's parent pointer fa and, if this node
// was the splay root, the root flag (keeping the path-parent on the root).
// Only this node's aggregates are recomputed here; y's are not.
void Splay::zag(int x)
{
    int y = ls, fa = pre;
    // y's right subtree becomes this node's left subtree. When it is empty
    // (ls == 0) this writes sp[0].pre, which the visible code never reads.
    ls = sp[y].rs, sp[ls].pre = x;
    sp[y].rs = x, pre = y;
    sp[y].pre = fa;
    if(root)
        root = false, sp[y].root = true;
    else
        // Re-point the old parent's child link from x to y.
        sp[fa].rs == x ? sp[fa].rs = y : sp[fa].ls = y;
    update();
}
// Splay this node (callers pass its own index as x) to the root of its splay
// tree.
// - If the parent y is the splay root, a single rotation is done.
// - Otherwise, with grandparent z, a zig-zig step (same direction: rotate z
//   first, then y) or a zig-zag step (opposite directions: rotate y first,
//   then z) is applied.
// The root flag, and with it the path-parent pointer held in pre, moves to x
// through zig()/zag(). Each rotated-down node is updated by its own rotation;
// x's aggregates are recomputed once at the end, when x is on top.
void Splay::splay(int x)
{
    int y, z;
    for(; !root; )
    {
        y = pre;
        if(sp[y].root)
            // Parent is the splay root: one rotation brings x to the top.
            sp[y].rs == x ? sp[y].zig(y) : sp[y].zag(y);
        else
        {
            z = sp[y].pre;
            if(sp[z].rs == y)
            {
                if(sp[y].rs == x)
                    // right-right: zig-zig, grandparent first
                    sp[z].zig(z), sp[y].zig(y);
                else
                    // right-left: zig-zag, parent first
                    sp[y].zag(y), sp[z].zig(z);
            }
            else
            {
                if(sp[y].ls == x)
                    // left-left: zig-zig, grandparent first
                    sp[z].zag(z), sp[y].zag(y);
                else
                    // left-right: zig-zag, parent first
                    sp[y].zig(y), sp[z].zag(z);
            }
        }
    }
    update();
}
// Build the initial link-cut forest from the adjacency lists by BFS from
// vertex 1. Every vertex becomes its own one-node splay tree. Its pre is the
// BFS parent, which acts as a path-parent pointer, and its key/sum is the
// weight of the edge to that parent. Vertex 1 is the root: pre = 0, key = 0.
void prepare()
{
    int head, tail = 0;
    // The sentinel node 0 must contribute nothing to size/sum aggregates.
    sp[0].size = sp[0].key = sp[0].sum = 0;
    q[tail++] = 1;
    sp[1].renew();
    sp[1].pre = 0;
    sp[1].key = sp[1].sum = 0;
    for(head = 0; head < tail; head++)
    {
        int u = q[head];
        for(int edge = first[u]; edge != -1; edge = next[edge])
        {
            int child = v[edge];
            if(child == sp[u].pre)
                continue;   // skip the edge back to u's parent
            sp[child].renew();
            sp[child].pre = u;
            sp[child].key = sp[child].sum = w[edge];
            q[tail++] = child;
        }
    }
}
// Prepend the directed edge x -> y with weight z to x's adjacency list.
void add(int x, int y, int z)
{
    v[e] = y;
    w[e] = z;
    next[e] = first[x];
    first[x] = e;
    ++e;
}
// Read one test case: the vertex count N, then N - 1 weighted edges "x y z".
// Each edge is stored in both directions, and the link-cut forest is then
// rebuilt with prepare().
// Every read is checked. If N cannot be read, the case is treated as empty
// (N = 0). If an edge line cannot be read, edge reading stops. Without these
// checks a failed read would leave N stale and x, y, z uninitialized, and
// add() would then index first[] with garbage.
void init()
{
    int i, x, y, z;
    if(scanf("%d", &N) != 1)
        N = 0;
    memset(first, -1, sizeof(first));
    e = 0;
    for(i = 1; i < N; i ++)
    {
        if(scanf("%d%d%d", &x, &y, &z) != 3)
            break;
        add(x, y, z), add(y, x, z);
    }
    prepare();
}
// Link-cut tree access. Makes the tree path from the root (vertex 1, whose
// pre is 0 after prepare()) down to x a single preferred path held in one
// splay tree. The loop walks upward through path-parent pointers until it
// reaches 0.
// Each iteration:
// - splays fx;
// - detaches fx's right (deeper) subtree into a separate splay tree; that
//   subtree's pre still points at fx and now acts as a path-parent;
// - attaches the previously processed tree as fx's right child.
// In the first iteration the previous tree is empty (0), so the part of x's
// old preferred path below x is cut off. x itself is not splayed again at the
// end.
void access(int x)
{
    int fx;
    for(fx = x, x = 0; fx != 0; x = fx, fx = sp[x].pre)
    {
        sp[fx].splay(fx);
        // Old deeper part becomes its own splay tree (may write sentinel sp[0]).
        sp[sp[fx].rs].root = true;
        // Splice in the tree coming from below; x == 0 only touches the sentinel.
        sp[fx].rs = x, sp[x].root = false;
        sp[fx].update();
    }
}
// Print the sum of edge weights on the tree path between x and y.
// After access(x), the path root..x lives in the splay tree whose root has
// pre == 0. Walking up from y with access-style splices, the first splayed
// node with pre == 0 is the lowest common ancestor of x and y. At that
// moment:
// - its right subtree holds the x-side nodes below the LCA;
// - `below` holds the y-side nodes below it.
// Each node's key is the weight of its parent edge, so the two sums add up
// to the path length.
void dist(int x, int y)
{
    access(x);
    int below = 0;
    for(int cur = y; cur != 0; )
    {
        sp[cur].splay(cur);
        if(!sp[cur].pre)
            printf("%d\n", sp[sp[cur].rs].sum + sp[below].sum);
        sp[sp[cur].rs].root = true;
        sp[cur].rs = below;
        sp[below].root = false;
        sp[cur].update();
        below = cur;
        cur = sp[cur].pre;
    }
}
// Return the index of the k-th node (1-based, in-order, i.e. shallowest
// first) of the splay subtree rooted at cur, descending by subtree sizes.
// Assumes 1 <= k <= sp[cur].size; callers are expected to guarantee this.
int Search(int cur, int k)
{
    for(;;)
    {
        int leftSize = sp[sp[cur].ls].size;
        if(k <= leftSize)
            cur = sp[cur].ls;
        else if(k == leftSize + 1)
            return cur;
        else
        {
            k -= leftSize + 1;
            cur = sp[cur].rs;
        }
    }
}
// Print the k-th vertex (1-based) on the tree path from x to y.
// The walk mirrors dist(): the first splayed node with pre == 0 is the LCA.
// At that point:
// - the LCA's right subtree holds the xSide vertices below it on the way to
//   x, in depth order, so x is the last one and the k-th vertex from x is
//   in-order position xSide - k + 1;
// - the LCA itself is position xSide + 1;
// - the remaining vertices come from `below`, the y-side vertices under the
//   LCA in depth order.
void kth(int x, int y, int k)
{
    access(x);
    int below = 0;
    for(int cur = y; cur != 0; )
    {
        sp[cur].splay(cur);
        if(sp[cur].pre == 0)
        {
            int xSide = sp[sp[cur].rs].size;
            int answer;
            if(k <= xSide)
                answer = Search(sp[cur].rs, xSide - k + 1);
            else if(k == xSide + 1)
                answer = cur;
            else
                answer = Search(below, k - xSide - 1);
            printf("%d\n", answer);
        }
        sp[sp[cur].rs].root = true;
        sp[cur].rs = below;
        sp[below].root = false;
        sp[cur].update();
        below = cur;
        cur = sp[cur].pre;
    }
}
// Process the commands of one test case until "DONE":
// - "DIST a b" prints the path length between a and b;
// - "KTH a b k" prints the k-th vertex on the path from a to b.
// Commands are told apart by their second letter ('I' / 'T' / 'O').
// Every read is checked. On EOF or malformed input a failed scanf would
// otherwise leave op unchanged and re-run the previous command forever. The
// command is read with a field width so a long token cannot overflow op.
void solve()
{
    int a, b, k;
    char op[10];
    for(;;)
    {
        if(scanf("%9s", op) != 1)
            break;
        if(op[1] == 'O')        // "DONE"
            break;
        if(op[1] == 'I')        // "DIST a b"
        {
            if(scanf("%d%d", &a, &b) != 2)
                break;
            dist(a, b);
        }
        else                    // "KTH a b k"
        {
            if(scanf("%d%d%d", &a, &b, &k) != 3)
                break;
            kth(a, b, k);
        }
    }
}
// Read the number of test cases, then build and query each tree in turn.
// t starts at 0 and the read is checked, so a missing count ends the program
// instead of using an uninitialized value. The `> 0` test keeps a negative
// count from looping through signed wrap-around.
int main()
{
    int t = 0;
    if(scanf("%d", &t) != 1)
        return 0;
    while(t -- > 0)
    {
        init();
        solve();
    }
    return 0;
}

你可能感兴趣的:(query)