[AGC008F] Black Radius(树形dp)

神题啊!!

Description

给你一棵有N个节点的树,节点编号为1到N,所有边的长度都为1

“全”对某些节点情有独钟,这些他喜欢的节点的信息会以一个长度为N的字符串s的形式给到你,具体一点就是对于1<=i<=N,si=1表示”全”喜欢节点i,为0表示”全”不喜欢节点i

一开始的时候,所有的节点都是白色的,”全”会进行以下操作恰好一次:

选择一个他喜欢的节点v和一个非负整数d,然后将所有与节点v距离不超过d的节点全部涂黑

问进行操作之后,有多少种不同的涂色情况?两种情况不同当且仅当两种情况存在一个节点i的颜色不同

Input

第一行一个正整数N

接下来N−1行每行两个正整数xi,yi表示xi到yi有一条边

最后一行一个字符串s

Output

输出不同染色情况的数量

题解:

f(i,d) f ( i , d ) ,为把距 i i 点小于等于 d d 染成黑色的集合。

不过这样染色的时候会有重复的方案,我们考虑下何时会重复。

f(i,d1) f ( i , d 1 ) f(j,d2) f ( j , d 2 ) ,重复时,当且仅当 i i j j 之间存在一点 k k ,使得 f(k,d1dist(i,k)) f ( k , d 1 − d i s t ( i , k ) ) dist(i,j) d i s t ( i , j ) 表示 i i j j 的树上距离。

这个很显然吧……

所以对于某个 f(i,d) f ( i , d ) ,如果有 f(k,d1dist(i,k))=f(i,d) f ( k , d 1 − d i s t ( i , k ) ) = f ( i , d ) ,那么 f(i,d) f ( i , d ) 就会被重复算。

对于每一个 i i ,我们只用考虑更他相邻的点和他的重复情况就可以了。(也很显然吧)

那我们再考虑一下,何时 f(i,d)=f(k,d1) f ( i , d ) = f ( k , d − 1 )

还是很显然,如果以 k k 为根, i i 的子树全部被染成黑色的话,他们就相等。

所以可以得知, d d 的上界为 i i 到他子树中的最远点的距离。

那我们用树形dp, O(n) O ( n ) 算出,然后暴力枚举每个 i i ,算 d d 的上界。

那下界呢? 如果 i i 点是特殊点,无疑是0,那如果不是呢?

如果不是,那么若存在一个特殊节点j满足方案 (i,d) ( i , d ) j j 所在子树内所有节点均被染成黑色, (i,d) ( i , d ) 就是一个合法的染色方案。故我们只需要求出从 i i 出发的至少经过 1 个特殊节点到达 j j 子树中的最远节点的距离的最小值,就是可行的 d d 的最小值。

知道每个 i i d d 的上下界,那就算吧!

CODE:

#include
#include
using namespace std;

int d1[200005],d2[200005];
int d3[200005],d4[200005];
//d1:子树中的最远距离
//d2:非子树中的最远距离
//d3:这个节点到子树中至少经过1个特殊节点到达子树中的最远节点最小距离 
//d4:从他父亲出发,不经过这棵子树的最远路径 
int n,x,y,siz[200005],col[200005],fa[200005];
long long ans=0;
int tot=0,h[200005];
struct Edge{
    int x,next;
}e[400005];
char ch[200005];

inline void add_edge(int x,int y){
    e[++tot].x=y;
    e[tot].next=h[x],h[x]=tot;
}

void dfs1(int x,int father){
    siz[x]=col[x],fa[x]=father;
    d3[x]=col[x]?0:1e9;
    for(int i=h[x];i;i=e[i].next){
        if(e[i].x==father)continue;
        dfs1(e[i].x,x);
        siz[x]+=siz[e[i].x];
        d1[x]=max(d1[x],d1[e[i].x]+1);
        if(siz[e[i].x])
            d3[x]=min(d3[x],d1[e[i].x]+1);
    }
}

void dfs2(int x,int father){
    if(father)d4[x]=d2[x]-1;
    int maxn=0,sec=0;
    for(int i=h[x];i;i=e[i].next){
        if(e[i].x==father)continue;
        if(d1[e[i].x]+1>maxn)
            sec=maxn,maxn=d1[e[i].x]+1;
        else sec=max(sec,d1[e[i].x]+1);
    }
    for(int i=h[x];i;i=e[i].next){
        if(e[i].x==father)continue;
        if(d1[e[i].x]+1==maxn)
            d2[e[i].x]=max(d2[x],sec)+1;
        else d2[e[i].x]=max(d2[x],maxn)+1;
        dfs2(e[i].x,x);
    }
}

int main(){
    scanf("%d",&n);
    for(int i=1;iscanf("%d%d",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }
    scanf("%s",ch+1);
    for(int i=1;ch[i];i++)
        col[i]=ch[i]-'0';
    dfs1(1,0),dfs2(1,0);
    for(int i=1;i<=n;i++){
        int minv,maxv;
        minv=min(d3[i],siz[1]==siz[i]?(int)1e9:d2[i]);
        maxv=max(d1[i],d2[i])-1;
        for(int j=h[i];j;j=e[j].next){
            if(e[j].x==fa[i])maxv=min(maxv,d1[i]+1);
            else maxv=min(maxv,d4[e[j].x]+1);
        }
        if(maxv>=minv)ans+=1LL*maxv-minv+1;
    }
    printf("%lld",ans+1);
}

你可能感兴趣的:(题解)