2019 CCPC-Final G.Game on the Tree(长链剖分+DP)

题目大意

两个人在一个以1为根的树上玩游戏,一开始硬币在1。然后每一轮,如果当前硬币在u,当前的人可以选择一个v把硬币移动到v,条件是移动的距离要大于上一轮的人移动的距离。第一个人可以随便移动。
问给出的树有多少个以1为根的子图可以让后手必胜。
n ≤ 2 e 5 n\le2e5 n2e5

解题思路

这题很容易推出来当1是树直径的中点的时候后手必胜。
然后想要得到1是树直径中点的方案数,可以先得到1的每个儿子 v v v的深度为 i ( 0 < = i < l e n [ v ] ) i(0<=ii(0<=i<len[v])的方案数,再根据这个 d p dp dp
f ( v , j ) f(v,j) f(v,j)表示结点 v v v最大深度为 j j j的子图方案数进行DP。
利用长链剖分优化这个转移,注意到长链剖分过程中,轻儿子转移完对应的深度之后,父节点还有一些数据是需要更新的。那么我们开一个额外的懒标记数组去标记,每当要用到对应的值的时候我们就更新它并让懒标记向后传递。
得到了所有的 f ( v , j ) f(v,j) f(v,j)之后我们用 d p ( i , 0 ) dp(i,0) dp(i,0)表示只有根1只有一个子树最大深度为 i i i且所有子树中最大的深度为 i i i的方案数, d p ( i , 1 ) dp(i,1) dp(i,1)表示根1有大于等于2个子树最大深度为 i i i且所有子树中最大深度为 i i i的方案数。和前面类似利用懒标记去转移。
ps:一定要注意懒标记的细节……多传或者漏传都会wa傻眼
代码后面提供1组测试数据

#include
#define ll long long
using namespace std;
const int maxn = 2e5 + 50;
const ll mod = 1e9 + 7;
ll temp[maxn*8], *f[maxn], *ex[maxn], *sz[maxn], *id, dp[maxn][2], exdp[maxn][2];
int len[maxn], son[maxn];
vector<int> g[maxn];
int n;
void DFS(int u, int fa){
    for(int i = 0; i < g[u].size(); ++i){
        int v = g[u][i];
        if(v == fa) continue;
        DFS(v, u);
        if(len[v] > len[son[u]]) son[u] = v;
    }len[u] = len[son[u]]+1; return;
}
void init(){
    cin>>n;
    for(int i = 0; i <= n+1; ++i) g[i].clear(), son[i] = 0, len[i] = 0, exdp[i][0] = exdp[i][1] = 1;
    for(int i = 1; i < n; ++i){
        int u, v; scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    DFS(1, 1);
}
inline void update(int u, int j){
    if(ex[u][j] == 1) return;
    f[u][j] = f[u][j]*ex[u][j]%mod;
    ex[u][j+1] = ex[u][j+1]*ex[u][j]%mod;
    ex[u][j] = 1;
}
ll tt[maxn];
void dfs(int u, int fa){
    f[u][0] = 1;
    ex[u][0] = 1;
    if(!son[u]) return;
    f[son[u]] = f[u] + 1;
    ex[son[u]] = ex[u] + 1;
    dfs(son[u], u);
    for(int i = 0; i < g[u].size(); ++i){
        int v = g[u][i];
        if(v == fa || v == son[u]) continue;
        f[v] = id; id += len[v]+1;
        ex[v] = id; id += len[v]+1;
        dfs(v, u);
        tt[0] = f[u][0];
        for(int j = 0; j < len[v]; ++j){
            update(v, j);  update(u, j);
            if(j > 0) tt[j] = (tt[j-1]+f[u][j])%mod, f[v][j] = (f[v][j]+f[v][j-1])%mod;
        }
        for(int j = 0; j < len[v]; ++j){
            update(u, j+1);//write as update(u,j), waµã
            if(j > 0)
                f[u][j+1] = (f[u][j+1]*(f[v][j]+1)%mod + tt[j]*(f[v][j]-f[v][j-1])%mod)%mod;
            else
                f[u][j+1] = (f[u][j+1]*(f[v][j]+1)%mod + tt[j]*f[v][j]%mod)%mod;
        }
        ex[u][len[v]+1] = ex[u][len[v]+1]*(f[v][len[v]-1]+1)%mod;
    }
}
int ca = 0;
inline void update(int i){
    dp[i][0] = dp[i][0]*exdp[i][0]%mod;
    exdp[i+1][0] = exdp[i][0]*exdp[i+1][0]%mod;
    exdp[i][0] = 1;

    dp[i][1] = dp[i][1]*exdp[i][1]%mod;
    exdp[i+1][1] = exdp[i][1]*exdp[i+1][1]%mod;
    exdp[i][1] = 1;
    return;
}
void sol(){
    if(g[1].size() < 2){
        printf("Case #%d: %d\n", ++ca, 1); return;
    }
    for(int i = 1; i <= len[1]; ++i) dp[i][0] = dp[i][1] = 0;
    dp[0][1] = 1;
    for(int i = 0; i < g[1].size(); ++i){
        int v = g[1][i];
        f[v] = id = temp; id += len[v]+1;
        ex[v] = id; id += len[v]+1;
        dfs(v, 1);
        for(int j = 0; j < len[v]; ++j){
            update(v, j);
            if(j > 0) f[v][j] = (f[v][j] + f[v][j-1])%mod;
        }
        tt[0] = dp[0][0] + dp[0][1];
        for(int j = 1; j <= len[v]; ++j) {
            update(j);
            tt[j] = (tt[j-1] + dp[j][0] + dp[j][1])%mod;
        }
        for(int j = 0; j < len[v]; ++j){
            if(j > 0) dp[j+1][1] = (dp[j+1][1] + dp[j+1][1]*f[v][j]%mod + dp[j+1][0]*(f[v][j]-f[v][j-1])%mod)%mod;
            else dp[j+1][1] = (dp[j+1][1] + dp[j+1][1]*f[v][j]%mod + dp[j+1][0]*f[v][j]%mod)%mod;

            if(j > 0) dp[j+1][0] = (dp[j+1][0]*(f[v][j-1] + 1)%mod + tt[j]*(f[v][j]-f[v][j-1])%mod)%mod;
            else dp[j+1][0] = (dp[j+1][0] + tt[j]*f[v][j])%mod;
        }
        exdp[len[v]+1][0] = exdp[len[v]+1][0]*(f[v][len[v]-1]+1)%mod;
        exdp[len[v]+1][1] = exdp[len[v]+1][1]*(f[v][len[v]-1]+1)%mod;
    }
    ll ans = 1;
    for(int i = 1; i <= len[1]; ++i) update(i), ans = (ans + dp[i][1])%mod;
    ans = (ans + mod)%mod;
    printf("Case #%d: %lld\n", ++ca, ans);
}
int main()
{
    //freopen("1.in", "r", stdin);
    int T;cin>>T;
    while(T--){
        init();
        sol();
    }
}
/*
2
14
1 2
2 3
3 4
4 5
3 6
2 7
7 8
7 9
9 10
1 11
11 12
12 13
13 14


13
1 2
2 3
3 4
4 5
2 6
6 7
7 8
1 9
9 10
10 11
11 12
3 13
*/

你可能感兴趣的:(启发式算法,数据结构,dp)