给你一棵有 n n n个节点的树,并用 01 01 01串告诉你哪些节点上有棋子(恰好一棵)。
你可以进行若干次操作,每次操作可以将两颗距离至少为 2 2 2的棋子向彼此移动一步。
问能否通过若干次操作使得所有的棋子都在一个点上。如果能,输出最小操作次数;否则,输出 − 1 -1 −1。
1 ≤ n ≤ 1 0 6 1\leq n\leq 10^6 1≤n≤106
时间限制 2000 m s 2000ms 2000ms,空间限制 256 M B 256MB 256MB。
看这道题之前可以先看看AGC034E Complete Compress,本题为其加强版。
我们可以枚举最后所有棋子聚集在哪个点,设这个点为 r r r,我们设 r r r为根。
设 d i s u dis_u disu表示 u u u的子树中每个棋子到 u u u的距离,那每一次操作会使 d i s r dis_r disr减少 2 2 2或者不变。我们发现, 如果不变的话,相当于浪费了一次,所以最优的肯定是选择减少 2 2 2。
每次减少 2 2 2,要不是在 r r r的一个儿子的 d i s dis dis值减少 2 2 2,要不是在 r r r的两个儿子分别减少 1 1 1。我们考虑什么时候无解,无解就是子树不能抵消完。设 m n u mn_u mnu表示子树 u u u中的棋子在内部操作若干次,直到不能再操作时的 d i s u dis_u disu(也就是需要与其他子树操作的最小次数)。设 v v v为 u u u的儿子,我们比较 m n v + s i z v mn_v+siz_v mnv+sizv和 d i s u − d i s v − s i z v dis_u-dis_v-siz_v disu−disv−sizv的大小( s i z v siz_v sizv表示子树 v v v中的棋子个数):
我们记录 u u u的所有儿子 v v v的 m n v + d i s v + 2 s i z v mn_v+dis_v+2siz_v mnv+disv+2sizv的最大值 m x u mx_u mxu,然后用这个最大值来与 d i s u dis_u disu作比较即可。
最后,看 m n r mn_r mnr是否为 0 0 0。如果为 0 0 0,则可以抵消完,则用 d i s r / 2 dis_r/2 disr/2来更新答案(这里的 d i s r dis_r disr是最开始的 d i s r dis_r disr);否则,以 r r r为最终聚集的点是无解的。
时间复杂度为 O ( n 2 ) O(n^2) O(n2)。下面考虑优化。
我们可以用换根 D P DP DP。设 s u m u sum_u sumu表示所有棋子到 u u u的距离,然后和上面类似地维护 f a d i s u fadis_u fadisu表示以 u u u为根节点,除去 u u u的子树部分之外的所有棋子到 u u u的距离, f a m n u famn_u famnu表示以 u u u为根节点时除去 u u u子树部分之外的部分内部操作若干次,直到不能再操作时的 d i s u dis_u disu。其实和上面的定义是类似的,只不过上面是在 u u u的子树中,这里是在 u u u的子树之外。
因为在更新 u u u的儿子 v v v的 f a m n v famn_v famnv时有可能用到自己的 m n v + d i s v + 2 s i z v mn_v+dis_v+2siz_v mnv+disv+2sizv,所以我们在记录 m n v + d i s v + 2 s i z v mn_v+dis_v+2siz_v mnv+disv+2sizv的最大值时还要记录次大值, m x u , 0 / 1 mx_{u,0/1} mxu,0/1分别表示最大值和次大值。
用与上面类似的方法来维护,然后对每个点判断是否有解,有解就更新答案。
时间复杂度为 O ( n ) O(n) O(n)。
可以参考代码帮助理解。
#include
using namespace std;
const int N=1000000;
int n,tot=0,d[2*N+5],l[2*N+5],r[N+5],dep[N+5],siz[N+5];
long long ans=1e18,dis[N+5],mn[N+5],fadis[N+5],sum[N+5],famn[N+5],mx[N+5][2];
char s[N+5];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs1(int u,int fa){
siz[u]=(s[u]=='1');
for(int i=r[u];i;i=l[i]){
int v=d[i];
if(v==fa) continue;
dfs1(v,u);
siz[u]+=siz[v];
dis[u]+=dis[v]+siz[v];
long long tmp=dis[v]+siz[v]*2+mn[v];
if(tmp>mx[u][0]){
mx[u][1]=mx[u][0];mx[u][0]=tmp;
}
else if(tmp>mx[u][1]) mx[u][1]=tmp;
}
if(mx[u][0]<=dis[u]) mn[u]=dis[u]%2;
else mn[u]=mx[u][0]-dis[u];
}
void dfs2(int u,int fa){
for(int i=r[u];i;i=l[i]){
int v=d[i];
if(d[i]==fa) continue;
long long now=dis[u]-dis[v]-siz[v];
fadis[v]=now+siz[1]-siz[v]+fadis[u];
sum[v]=dis[v]+fadis[v];
long long famx=mx[u][0],fasum=sum[u]-dis[v]-siz[v];
if(famx==dis[v]+siz[v]*2+mn[v]) famx=mx[u][1];
famx=max(famx,famn[u]+fadis[u]);
if(famx<=fasum) famn[v]=fasum%2+siz[1]-siz[v];
else famn[v]=famx-fasum+siz[1]-siz[v];
long long mxd=max(mx[v][0],fadis[v]+famn[v]);
if(mxd<=sum[v]&&sum[v]%2==0) ans=min(ans,sum[v]/2);
dfs2(v,u);
}
}
int main()
{
// freopen("charlotte.in","r",stdin);
// freopen("charlotte.out","w",stdout);
scanf("%d",&n);
scanf("%s",s+1);
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs1(1,0);
if(mn[1]==0) ans=dis[1]/2;
sum[1]=dis[1];
dfs2(1,0);
if(ans==1e18) printf("-1\n");
else printf("%lld\n",ans);
return 0;
}