Luogu4689 [Ynoi2016]这是我自己的发明 【莫队】

题目链接:洛谷

又来做Ynoi里面的水题了。。。

换根是一个常见套路:先以1为根做一遍dfs,再画图分析以rt为根时x的子树是什么:若rt=x,则为整棵树;若rt在(以1为根时)x的子树内,设c为x通往rt方向的儿子,则为整棵树去掉c的子树;否则就是x原来的子树。因此它最多可以拆分为2段dfs序上的连续区间。

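然后,记\(f(p,q)\)为dfs序区间\([1,p]\)与\([1,q]\)的答案(即两段中权值相同的点对个数)。要计算\([l_1,r_1]\)与\([l_2,r_2]\)的答案,做一次二维差分即可:\(f(r_1,r_2)-f(l_1-1,r_2)-f(r_1,l_2-1)+f(l_1-1,l_2-1)\),这样就全部转化成了\([1,r_1]\)与\([1,r_2]\)形式的询问。把\((r_1,r_2)\)当作莫队的两个端点来做就可以过了。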

注意一点:要去除那些不必要的询问,即\(r_1=0\)或者\(r_2=0\)的询问(其中一段前缀为空,答案恒为0)。这样可以去掉大量的询问,否则会T掉3个点。

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
// `register` is no longer a valid storage-class specifier in C++17 (clang rejects
// it outright), so Rint is a plain int. It stays a macro because the loops below use it.
#define Rint int
using namespace std;
typedef long long LL;
const int N = 100003;   // node-indexed arrays hold nodes 1..n
// n: nodes, m: operations, tot: Mo queries queued, len: distinct values,
// blo: Mo block size, rt: current root, tim: dfs timestamp counter.
int n, m, tot, len, blo, rt, tim;
// Adjacency list; each undirected edge is stored in both directions.
int head[N], to[N << 1], nxt[N << 1];
// a: node values (coordinate-compressed in main), b: sorted copy used for that.
int a[N], b[N];
// Heavy-light decomposition data; pre[] is the inverse of dfn[].
int dfn[N], pre[N], dep[N], fa[N], siz[N], wson[N], top[N];
// One answer per type-2 operation (indexed 1..pos in main);
// NOTE(review): presumably sized for m <= 5e5 — confirm against the problem limits.
LL ans[N * 5];
// Appends the directed edge a -> b to the adjacency list (head / nxt / to).
// Edge ids start at 1 so that 0 can serve as the end-of-list marker.
inline void add(int a, int b){
    static int edgeCount = 0;
    ++ edgeCount;
    to[edgeCount] = b;
    nxt[edgeCount] = head[a];
    head[a] = edgeCount;
}
inline void dfs1(int x){
    siz[x] = 1; 
    for(Rint i = head[x];i;i = nxt[i])
        if(to[i] != fa[x]){
            dep[to[i]] = dep[x] + 1; fa[to[i]] = x;
            dfs1(to[i]);
            siz[x] += siz[to[i]];
            if(siz[to[i]] > siz[wson[x]]) wson[x] = to[i];
        }
}
inline void dfs2(int x, int tp){
    top[x] = tp; dfn[x] = ++ tim; pre[tim] = x;
    if(wson[x]){
        dfs2(wson[x], tp);
        for(Rint i = head[x];i;i = nxt[i])
            if(to[i] != fa[x] && to[i] != wson[x])
                dfs2(to[i], to[i]);
    }
}
// Returns the child of y that lies on the path from y down to x, using the
// heavy-light chains. Callers (main) only use it when y is a proper ancestor of x
// in the tree rooted at 1.
inline int calc(int x, int y){
    for(; dep[x] > dep[y]; x = fa[top[x]]){
        const int chainHead = top[x];
        // x's chain reaches y's depth or higher: x and y share a heavy chain,
        // so the step just below y is y's heavy child.
        if(dep[chainHead] <= dep[y]) return wson[y];
        // This chain hangs directly below y.
        if(fa[chainHead] == y) return chainHead;
    }
    return x;
}
// One term of the 2D difference: the answer for dfn prefixes [1, l] and [1, r]
// is added to ans[id], or subtracted when flag is true.
// Ordering is Mo's order: by block of l, then by r, ascending in even blocks
// and descending in odd blocks (odd-even trick to reduce pointer travel).
struct Query {
    int l, r, id;
    bool flag;
    inline bool operator < (const Query &o) const {
        const int myBlock = l / blo;
        const int otherBlock = o.l / blo;
        if(myBlock != otherBlock) return myBlock < otherBlock;
        return (myBlock & 1) ? (r > o.r) : (r < o.r);
    }
} que[N * 80];
// Queues the 2D-difference terms for "number of equal-value pairs between dfn
// ranges [l1, r1] and [l2, r2]" as answer #id. With f(p, q) the answer for the
// prefixes [1, p] and [1, q]:
//   f(r1, r2) - f(l1-1, r2) - f(r1, l2-1) + f(l1-1, l2-1).
// A term with a zero-length prefix is 0 and is not queued; flag = true marks a
// subtracted term.
inline void add(int l1, int r1, int l2, int r2, int id){
    // An empty range contributes no pairs. main can produce [n+1, n] when rerooting;
    // for it the four terms would cancel exactly, so skipping saves up to 4 Mo queries.
    if(l1 > r1 || l2 > r2) return;
    if(r1 && r2){que[++ tot].l = r1; que[tot].r = r2; que[tot].id = id; que[tot].flag = false;}
    if(l1 > 1 && r2){que[++ tot].l = l1 - 1; que[tot].r = r2; que[tot].id = id; que[tot].flag = true;}
    if(r1 && l2 > 1){que[++ tot].l = r1; que[tot].r = l2 - 1; que[tot].id = id; que[tot].flag = true;}
    if(l1 > 1 && l2 > 1){que[++ tot].l = l1 - 1; que[tot].r = l2 - 1; que[tot].id = id; que[tot].flag = false;}
}
// Mo's algorithm state. The current window is the pair of dfn prefixes [1, ql]
// and [1, qr]. cnt1 / cnt2 count each compressed value inside those prefixes, and
// qans is the sum over values v of cnt1[v] * cnt2[v].
int ql = 0, qr = 0;
int cnt1[N], cnt2[N];
LL qans = 0;
// Extend / shrink the first prefix by one element of value x.
inline void add1(int x){ qans += cnt2[x]; cnt1[x] ++; }
inline void del1(int x){ qans -= cnt2[x]; cnt1[x] --; }
// Extend / shrink the second prefix by one element of value x.
inline void add2(int x){ qans += cnt1[x]; cnt2[x] ++; }
inline void del2(int x){ qans -= cnt1[x]; cnt2[x] --; }
int main(){
    scanf("%d%d", &n, &m); blo = sqrt(n);
    for(Rint i = 1;i <= n;i ++) scanf("%d", a + i), b[i] = a[i];
    sort(b + 1, b + n + 1);
    len = unique(b + 1, b + n + 1) - b - 1;
    for(Rint i = 1;i <= n;i ++) a[i] = lower_bound(b + 1, b + len + 1, a[i]) - b;
    for(Rint i = 1;i < n;i ++){
        int a, b; scanf("%d%d", &a, &b); add(a, b); add(b, a);
    }
    dfs1(1); dfs2(1, 1); rt = 1;
    int pos = 0;
    while(m --){
        int opt, x, y;
        scanf("%d", &opt);
        if(opt == 1) scanf("%d", &rt);
        else {
            scanf("%d%d", &x, &y); ++ pos;
            int l1[2], r1[2], l2[2], r2[2], cnt1 = 0, cnt2 = 0;
            if(rt == x) l1[0] = 1, r1[0] = n, cnt1 = 1;
            else if(dfn[rt] > dfn[x] && dfn[rt] < dfn[x] + siz[x]){
                int tmp = calc(rt, x);
                l1[0] = 1; r1[0] = dfn[tmp] - 1; l1[1] = dfn[tmp] + siz[tmp]; r1[1] = n; cnt1 = 2;
            } else l1[0] = dfn[x], r1[0] = dfn[x] + siz[x] - 1, cnt1 = 1;
            if(rt == y) l2[0] = 1, r2[0] = n, cnt2 = 1;
            else if(dfn[rt] > dfn[y] && dfn[rt] < dfn[y] + siz[y]){
                int tmp = calc(rt, y);
                l2[0] = 1; r2[0] = dfn[tmp] - 1; l2[1] = dfn[tmp] + siz[tmp]; r2[1] = n; cnt2 = 2;
            } else l2[0] = dfn[y], r2[0] = dfn[y] + siz[y] - 1, cnt2 = 1;
            for(Rint i = 0;i < cnt1;i ++)
                for(Rint j = 0;j < cnt2;j ++)
                    add(l1[i], r1[i], l2[j], r2[j], pos);
        }
    }
    for(Rint i = 1;i <= tot;i ++)
        if(que[i].l > que[i].r) swap(que[i].l, que[i].r);
    sort(que + 1, que + tot + 1);
    for(Rint i = 1;i <= tot;i ++){
        while(ql < que[i].l) add1(a[pre[++ ql]]);
        while(ql > que[i].l) del1(a[pre[ql --]]);
        while(qr < que[i].r) add2(a[pre[++ qr]]);
        while(qr > que[i].r) del2(a[pre[qr --]]);
        if(que[i].flag) ans[que[i].id] -= qans;
        else ans[que[i].id] += qans;
    }
    for(Rint i = 1;i <= pos;i ++) printf("%lld\n", ans[i]);
}

你可能感兴趣的:(Luogu4689 [Ynoi2016]这是我自己的发明 【莫队】)