Codeforces Round #606 (Div. 1, based on Technocup 2020 Elimination Round 4)

场次链接

A、B、C即div2的C、E、F,需要请查看div2.链接

D. Tree Elimination
题目链接
你有n个数,初始队列为空,后面n-1行每行给出 x , y x,y x,y,代表 x , y x,y x,y之间有边,最后构成一个树,你能做一种操作,如果一条边的2个节点都没有在队列中,你可以将其中一个加入队列,问最后你能得到多少个不同的队列。
数据范围 2 ≤ n ≤ 2 ∗ 1 0 5 2\leq n\leq 2*10^5 2n2105, 1 ≤ x i , y i ≤ n 1\leq x_i,y_i\leq n 1xi,yin
解 树形dp, d p [ i ] [ 0 ] dp[i][0] dp[i][0]代表 i i i在与其父节点比较前就被加入队列, d p [ i ] [ 1 ] dp[i][1] dp[i][1]代表 i i i在和父节点比较时被加入队列, d p [ i ] [ 2 ] dp[i][2] dp[i][2]代表 i i i在与父节点比较后被加入队列,或是没有被加入。
设l为该节点连接的子节点长度,k为父节点的位置。

d p [ v ] [ 0 ] dp[v][0] dp[v][0]可以枚举v被加入的位置(在k前),之前的子节点必定被加入,即 ∏ ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) \prod(dp[u_j][0]+dp[u_j][1]) (dp[uj][0]+dp[uj][1]),在这个节点v被加入,故为 d p [ u j ] [ 2 ] dp[u_j][2] dp[uj][2],后面的节点已满足2情况,或是已经在队列中 即0情况,故为 ∏ ( d p [ u j ] [ 0 ] + d p [ u j ] [ 2 ] ) \prod(dp[u_j][0]+dp[u_j][2]) (dp[uj][0]+dp[uj][2])

d p [ v ] [ 1 ] dp[v][1] dp[v][1]为在和父节点比较时被加入队列,故前面的节点在与v比较时被加入,满足1或者已经被加入,满足0,即 ∏ ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) \prod (dp[u_j][0]+dp[u_j][1]) (dp[uj][0]+dp[uj][1]);后面的节点满足父节点被加入,满足2或者比较前已经被加入,满足0,即 ∏ ( d p [ u j ] [ 0 ] + d p [ u j ] [ 2 ] ) \prod (dp[u_j][0]+dp[u_j][2]) (dp[uj][0]+dp[uj][2])

d p [ v ] [ 2 ] dp[v][2] dp[v][2]为父节点被加入,可以枚举v被加入的位置(在k后),之前的子节点必定被加入,即 ∏ ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) \prod (dp[u_j][0]+dp[u_j][1]) (dp[uj][0]+dp[uj][1]),在这个节点v被加入,故为 d p [ u j ] [ 2 ] dp[u_j][2] dp[uj][2],后面的节点已满足2情况,或是已经在队列中 即0情况,故为 ∏ ( d p [ u j ] [ 0 ] + d p [ u j ] [ 2 ] ) \prod (dp[u_j][0]+dp[u_j][2]) (dp[uj][0]+dp[uj][2]),或者是v没有被加入,即 ∏ ( d p [ v j ] [ 0 ] + d p [ u j ] [ 1 ] ) \prod (dp[v_j][0]+dp[u_j][1]) (dp[vj][0]+dp[uj][1])

可以得到转移公式 d p [ v ] [ 0 ] = ∑ i = 1 k ( ∏ j = 1 i − 1 ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) ∗ d p [ u i ] [ 2 ] ∗ ∏ j = k + 1 l ( d p [ u j ] [ 0 ] + d p [ u j ] [ 2 ] ) dp[v][0]=\sum_{i=1}^{k}(\prod_{j=1}^{i-1}(dp[u_j][0]+dp[u_j][1])*dp[u_i][2]*\prod_{j=k+1}^{l}(dp[u_j][0]+dp[u_j][2]) dp[v][0]=i=1k(j=1i1(dp[uj][0]+dp[uj][1])dp[ui][2]j=k+1l(dp[uj][0]+dp[uj][2])
d p [ v ] [ 1 ] = ∏ j = 1 k ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) ∗ ∏ j = k + 1 l ( d p [ v j ] [ 0 ] + d p [ v j ] [ 2 ] ) dp[v][1]=\prod_{j=1}^k(dp[u_j][0]+dp[u_j][1])*\prod_{j=k+1}^{l}(dp[v_j][0]+dp[v_j][2]) dp[v][1]=j=1k(dp[uj][0]+dp[uj][1])j=k+1l(dp[vj][0]+dp[vj][2])
d p [ v ] [ 2 ] = ∑ i = k + 1 l ( ∏ j = 1 i − 1 ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) ∗ d p [ u i ] [ 2 ] ∗ ∏ j = k + 1 l ( d p [ u j ] [ 0 ] + d p [ u j ] [ 2 ] ) ) + ∏ i = 1 l ( d p [ u j ] [ 0 ] + d p [ u j ] [ 1 ] ) dp[v][2]=\sum_{i=k+1}^{l}(\prod_{j=1}^{i-1}(dp[u_j][0]+dp[u_j][1])*dp[u_i][2]*\prod_{j=k+1}^{l}(dp[u_j][0]+dp[u_j][2]))+\prod_{i=1}^{l}(dp[u_j][0]+dp[u_j][1]) dp[v][2]=i=k+1l(j=1i1(dp[uj][0]+dp[uj][1])dp[ui][2]j=k+1l(dp[uj][0]+dp[uj][2]))+i=1l(dp[uj][0]+dp[uj][1])

最后答案为1被加入的情况+1没被加入的情况 即 d p [ 1 ] [ 0 ] + d p [ 1 ] [ 1 ] dp[1][0]+dp[1][1] dp[1][0]+dp[1][1]或者 d p [ 1 ] [ 0 ] + d p [ 1 ] [ 2 ] dp[1][0]+dp[1][2] dp[1][0]+dp[1][2]
提高效率可将 d p [ u ] [ 0 ] + d p [ u ] [ 1 ] dp[u][0]+dp[u][1] dp[u][0]+dp[u][1]的前缀积和 d p [ u ] [ 0 ] + d p [ u ] [ 2 ] dp[u][0]+dp[u][2] dp[u][0]+dp[u][2]的后缀积求出。
复杂度 O ( n ) O(n) O(n)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
vector<ll>v[200005];
ll dp[200005][3];
ll f[200005];
ll g[200005];
const int mod=998244353;
void dfs(int p,int u)
{
    //printf("%d %d\n",p,u);
    for(auto x:v[u]){
        if(x!=p){
            dfs(u,x);
        }
    }
    //printf("%d %d\n",p,u);
    int l=v[u].size()-(p?1:0);
    if(l==0){
        dp[u][0]=0;dp[u][1]=1;dp[u][2]=1;
        return;
    }
    int k=0;
    while(k<v[u].size()&&v[u][k]!=p)k++;
    int cnt=1;
    f[0]=1;
    for(int i=0;i<v[u].size();i++){
        int tmp=v[u][i];
        if(tmp==p)continue;
        f[cnt]=f[cnt-1]*(dp[tmp][0]+dp[tmp][1])%mod;
        cnt++;
    }
    cnt=l;
    g[l+1]=1;
    for(int i=v[u].size()-1;i>=0;i--){
        int tmp=v[u][i];
        if(tmp==p)continue;
        g[cnt]=g[cnt+1]*(dp[tmp][0]+dp[tmp][2])%mod;
        cnt--;
    }
    for(int i=1;i<=k;i++){
        dp[u][0]=(dp[u][0]+f[i-1]*g[i+1]%mod*dp[v[u][i-1]][2])%mod;
    }
    dp[u][1]=f[k]*g[k+1]%mod;
    for(int i=k+1;i<=l;i++){
        dp[u][2]=(dp[u][2]+f[i-1]*g[i+1]%mod*dp[v[u][i]][2])%mod;
    }
    dp[u][2]=(dp[u][2]+f[l])%mod;
}
void work()
{
    int n;
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        v[x].push_back(y);
        v[y].push_back(x);
    }
    dfs(0,1);
    //printf("%lld %lld %lld\n",dp[1][0],dp[1][1],dp[1][2]);
    printf("%lld\n",(dp[1][0]+dp[1][1])%mod);
}
int main()
{
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int T;
    //scanf("%d",&T);
    //cin>>T;
    T=1;
    while(T--){
        work();
    }
}

你可能感兴趣的:(codeforces)