[CSU 1915 John and his farm]树形DP+LCA

[CSU 1915 John and his farm]树形DP+LCA

分类:Tree DP LCA

1. 题目链接

[CSU 1915 John and his farm]

2. 题意描述

有一棵 N 个节点的树,树上每条边长度为1。
现在需要等概率地随机地在两个顶点之间加一条边。
M 次询问。
每次查询给定两个顶点 u,v 。求增加一条边,保证顶点 u,v 在一个环内的条件下,求环的长度的数学期望。
要求结果保证误差在 106 以内。
数据范围: (2N,M200000)

3. 解题思路

首先,请见下图。
[CSU 1915 John and his farm]树形DP+LCA_第1张图片
然后,现在的问题就是求sum,siz数组。首先,dfs求出所有节点为根节点的子树的sum,siz。
另外,还需要一个dfs,求出以当前节点为根节点时,整棵树的dep之和,记为all。转移方程是:all[v] = all[u] + (n - 1 - siz[v]) - (siz[v] - 1);

现在就要分两种情况讨论(假设dep[u]<=dep[v]):

  • lca(u, v) !=u, 这个情况比较简单,T(u),T(v) 的sum(T(u)), sum(T(v)), siz(T(u)), siz(T(v)),直接就是sum[u], sum[v], siz[u], siz[v]。
  • lca(u, v)==u, 此时T(v) 的sum(T(v))=sum[v], siz(T(v))=siz[v],但是,siz(T(u)) = n - siz[w], sum(T(u))=all[u] - sum[w] - siz[w]; (顶点w是在从u到v的链上,且是u的儿子)。

看起来比较复杂,但是自己手算理解一下,就很简单了。
这题,还需要注意sum和all 会爆int。

4. 实现代码

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;

typedef long long LL;
typedef long double LB;
typedef pair<int, int> PII;
typedef pair PLL;
typedef vector<int> VI;

const int INF = 0x3f3f3f3f;
const LL INFL = 0x3f3f3f3f3f3f3f3fLL;

void debug() { cout << endl; }
template<typename T, typename ...R> void debug (T f, R ...r) { cout << "[" << f << "]"; debug (r...); }


const int MAXN = 1e5 + 5;
const int MAXM = 20;

int n, m;
struct Edge {
    int v, next;
} edge[MAXN << 1];
int head[MAXN], tot;
int dep[MAXN], siz[MAXN], fa[MAXN][MAXM];
LL all[MAXN], sum[MAXN];
int root;

void init_edge() {
    tot = 0;
    memset(head, -1, sizeof(head));
}
inline void add_edge(int u, int v) {
    edge[tot] = Edge {v, head[u]};
    head[u] = tot ++;
}
void dfs(int u, int pre, int d) {
    int v;
    siz[u] = 1;
    dep[u] = d;
    sum[u] = 0;
    fa[u][0] = pre;
    for(int i = head[u]; ~i; i = edge[i].next) {
        v = edge[i].v;
        if(v == pre) continue;
        dfs(v, u, d + 1);
        siz[u] += siz[v];
        sum[u] += sum[v];
        sum[u] += siz[v];
    }
}

void dfs2(int u, int pre) {
    int v;
    for(int i = head[u]; ~i; i = edge[i].next) {
        v = edge[i].v;
        if(v == pre) continue;
        all[v] = all[u] + (n - 1 - siz[v]) - (siz[v] - 1);
        dfs2(v, u);
    }
}

void lca_init() {
    for(int j = 1; j < MAXM; ++j) {
        for(int i = 1; i <= n; ++i) {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }
}

int lca(int u, int v) {
    while(dep[u] != dep[v]) {
        if(dep[u] < dep[v]) swap(u, v);
        int d = dep[u] - dep[v];
        for(int i = 0; i < MAXM; i++) {
            if(d >> i & 1) u = fa[u][i];
        }
    }
    if(u == v) return u;
    for(int i = MAXM - 1; i >= 0; i--) {
        if(fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}

int son(int u, int v) {
    while(dep[v] > dep[u] + 1) {
        int w = v;
        for(int j = 0; j < MAXM; ++j) {
            if(dep[fa[v][j]] < dep[u] + 1) break;
            w = fa[v][j];
        }
        v = w;
    }
    return v;
}

int main() {
#ifdef ___LOCAL_WONZY___
    freopen ("input.txt", "r", stdin);
#endif // ___LOCAL_WONZY___
    int u, v, w;
    while(~scanf("%d %d", &n, &m)) {
        init_edge();
        for(int i = 1; i <= n - 1; ++i) {
            scanf("%d %d", &u, &v);
            add_edge(u, v);
            add_edge(v, u);
        }

        dfs(root = 1, 0, 0);
        all[root] = sum[root];
        dfs2(root, 0);

        lca_init();
//        for(int i = 1; i <= n; ++i) {
//            printf("[%d: sum=%d siz=%d all=%d]\n", i, sum[i], siz[i], all[i]);
//        }
        while(m --) {
            scanf("%d %d", &u, &v);
            if(dep[u] > dep[v]) swap(u, v);
            w = lca(u, v);
            int dist, sizu, sizv;
            LL sumu, sumv;
            double ans;
            if(w != u) {
                /** 有lca **/
                dist = dep[u] + dep[v] - 2 * dep[w];
                sizu = siz[u];
                sumu = sum[u];
                sizv = siz[v];
                sumv = sum[v];
            } else {
                /**一条链**/
                dist = dep[v] - dep[u];
                w = son(u, v);
                sizu = n - siz[w];
                sumu = all[u] - sum[w] - siz[w];
                sizv = siz[v];
                sumv = sum[v];
            }
            ans = 1.0 + dist + 1.0 * ((LL)sizu * sumv + (LL)sizv * sumu) / ((LL)sizu * sizv);
            printf("%.8f\n", ans);
        }
    }
#ifdef ___LOCAL_WONZY___
    cout << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC * 1000 << " ms." << endl;
#endif // ___LOCAL_WONZY___
    return 0;
}

你可能感兴趣的:(ACM____数据结构,ACM____动态规划,dp,lca)