2020杭电多校第六场 A Very Easy Graph Problem 点分治 (HDU 6832)

A Very Easy Graph Problem

不知道有没有人跟我一样点分治写的…

题解

根据题意,第 i i i条边长度为 2 i 2^i 2i,且 2 1 + 2 2 + ⋅ ⋅ ⋅ + 2 i − 1 = 2 i − 1 < 2 i 2^1+2^2+···+2^{i - 1} =2^i-1<2^i 21+22++2i1=2i1<2i,即后面加进去的边要大于前面所有边之和
那么显然,如果在加入当前边 ( u , v , 2 i ) (u,v,2^i) (u,v,2i)之前, u u u v v v已经连通了,我们就不用再加这条边了(加之前能连通,而且根据上面说的性质,可以知道最短路一定不会经过这条边),所以我们按照这样建立的图,其实就是一颗树

再来看题目要求的式子:
∑ i = 1 n ∑ j = 1 n d ( i , j ) × [ a i = 1 ∧ a j = 0 ] \displaystyle\sum_{i = 1} ^ {n}\displaystyle\sum_{j = 1} ^ {n}d(i, j) × [a_i = 1∧a_j=0] i=1nj=1nd(i,j)×[ai=1aj=0]
很容易想到点分治(处理树上点对问题)
其实就是求 a i = 1 a_i=1 ai=1的所有点与 a j = 1 a_j=1 aj=1的所有点的距离之和

关键部分:

ll ans, sum[2], cnt[2], tsum[2], tcnt[2];
//sum[1]记录之前搜过的子树中a[i]=1的dis之和, sum[0]同理
//cnt[1]记录之前搜过的子树中a[i]=1的点的个数, cnt[0]同理
void getAns(int u, int fa, ll dis) {
    tsum[a[u]] = (tsum[a[u]] + dis) % mod;
    tcnt[a[u]] = (tcnt[a[u]] + 1) % mod;
    ans = ((ans + sum[a[u] ^ 1]) % mod + cnt[a[u] ^ 1] * dis % mod) % mod;
    //答案就是加上之前搜的a[i]!=a[u]的点到根的距离之和+当前点到根的距离*之前搜的a[i]!=a[u]的点的个数
    for (int i = head[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to] && v != fa)
            getAns(v, u, (dis + e[i].w) % mod);
}

void calc(int u) {
    sum[0] = sum[1] = cnt[0] = cnt[1] = 0;
    cnt[a[u]] = 1;
    for (int i = head[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to]) {
            tsum[0] = tsum[1] = tcnt[0] = tcnt[1] = 0;
            getAns(v, u, e[i].w);
            //延迟加入,否则贡献有重复
            sum[0] = (sum[0] + tsum[0]) % mod;
            sum[1] = (sum[1] + tsum[1]) % mod;
            cnt[0] = (cnt[0] + tcnt[0]) % mod;
            cnt[1] = (cnt[1] + tcnt[1]) % mod;
        }
}

代码

#include
using namespace std;
typedef long long ll;
const int N = 2e5 + 10;
const ll mod = 1e9 + 7;

struct Edge {
    int nxt, to; ll w;
} e[N << 1];
int head[N], tot;
void add(int u, int v, ll w) { e[++tot] = Edge{head[u], v, w}; head[u] = tot; }


int n, m;
int a[N], pre[N];
int find(int x) { return x == pre[x] ? x : pre[x] = find(pre[x]); }

int maxp[N], siz[N], vis[N], rt = 1;

void getRt(int u, int fa, int all) {
    siz[u] = 1, maxp[u] = 0;
    for (int i = head[u], v = e[i].to; i; i = e[i].nxt, v = e[i].to)
        if (v != fa && !vis[v]) {
            getRt(v, u, all);
            siz[u] += siz[v];
            maxp[u] = max(maxp[u], siz[v]);
        }
    maxp[u] = max(maxp[u], all - siz[u]);
    if (maxp[u] < maxp[rt]) rt = u;
}

//这之间是关键--------------------别的都是板子
ll ans, sum[2], cnt[2], tsum[2], tcnt[2];
//sum[1]记录之前搜过的子树中a[i]=1的dis之和, sum[0]同理
//cnt[1]记录之前搜过的子树中a[i]=1的点的个数, cnt[0]同理
void getAns(int u, int fa, ll dis) {
    tsum[a[u]] = (tsum[a[u]] + dis) % mod;
    tcnt[a[u]] = (tcnt[a[u]] + 1) % mod;
    ans = ((ans + sum[a[u] ^ 1]) % mod + cnt[a[u] ^ 1] * dis % mod) % mod;
    //答案就是加上之前搜的a[i]!=a[u]的点到根的距离之和+当前点到根的距离*之前搜的a[i]!=a[u]的点的个数
    for (int i = head[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to] && v != fa)
            getAns(v, u, (dis + e[i].w) % mod);
}

void calc(int u) {
    sum[0] = sum[1] = cnt[0] = cnt[1] = 0;
    cnt[a[u]] = 1;
    for (int i = head[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to]) {
            tsum[0] = tsum[1] = tcnt[0] = tcnt[1] = 0;
            getAns(v, u, e[i].w);
            //延迟加入,否则贡献有重复
            sum[0] = (sum[0] + tsum[0]) % mod;
            sum[1] = (sum[1] + tsum[1]) % mod;
            cnt[0] = (cnt[0] + tcnt[0]) % mod;
            cnt[1] = (cnt[1] + tcnt[1]) % mod;
        }
}
//------------------------

void dfs(int u) {
    vis[u] = 1;
    calc(u);
    for (int i = head[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to]) {
            maxp[rt = 0] = N; getRt(v, 0, siz[v]);
            dfs(rt);
        }
}

int main() {

    int T; scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        for (int i = 0; i <= n; i++) pre[i] = i, head[i] = vis[i] = 0; tot = 0;
        for (int i = 1; i <= n; i++)
            scanf("%d", &a[i]);
        ll pow2 = 2;
        for (int i = 1; i <= m; i++, pow2 = pow2 * 2ll % mod) {
            int x, y; scanf("%d%d", &x, &y);
            int u = find(x), v = find(y);
            if (u == v) continue; pre[v] = u;
            add(x, y, pow2); add(y, x, pow2);
        }
        maxp[rt = 0] = N; getRt(1, 0, N);
        ans = 0; dfs(rt);
        printf("%lld\n", ans);
    }

    return 0;
}

你可能感兴趣的:(#,点分治,树和森林)