这应该是一道很裸的虚树了吧。。。
(最近发现把树形dp的递归形式改成按dfs序列排序后倒序操作会变快!!O(NlogN)<O(N)23333。。)
首先构造出关键点的虚树。
对于虚树中的每一个点,用sum[x]表示所有以x为顶点的链的总长度,f[x]表示以x为顶点的链的最小值,g[x]表示最大值。显然f[x]和g[x]的转移是很方便的,更新答案也很方便。而sum[x]的转移也是很方便的,关键是如何用sum[x]更新答案。
令y=fa[x](指虚树中的fa),直接求是很玛法的,但是我们可以求出在经过y的链中,sum[x]的贡献,实际上就是(sz[y]-sz[x])*sum[x],这样就要先弄出sz[]了,实际上可以不用,具体见下面的代码。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define inf 1000000000 #define N 1000005 #define ll long long using namespace std; int n,m,tot,dfsclk,fst[N],pnt[N<<1],nxt[N<<1],bin[25],pos[N],fa[N][20],d[N]; int a[N],p[N],sz[N],f[N],g[N],len[N],anc[N],q[N]; ll sum[N]; bool bo[N]; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void add(int x,int y){ pnt[++tot]=y; nxt[tot]=fst[x]; fst[x]=tot; } void dfs(int x){ pos[x]=++dfsclk; int i,p; for (i=1; bin[i]<=d[x]; i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; if (y!=fa[x][0]){ fa[y][0]=x; d[y]=d[x]+1; dfs(y); } } } int lca(int x,int y){ if (d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y],i; for (i=0; bin[i]<=tmp; i++) if (tmp&bin[i]) x=fa[x][i]; for (i=19; i>=0; i--) if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; } return (x==y)?x:fa[x][0]; } bool cmp(int x,int y){ return pos[x]<pos[y]; } void solve(){ m=read(); int i,tp=0,cnt=m; for (i=1; i<=m; i++){ p[i]=a[i]=read(); bo[a[i]]=1; } sort(a+1,a+m+1,cmp); for (i=1; i<=m; i++) if (!tp){ q[++tp]=a[i]; anc[a[i]]=0; } else{ int tmp=lca(a[i],q[tp]); for (; d[q[tp]]>d[tmp]; tp--) if (d[q[tp-1]]<=d[tmp]) anc[q[tp]]=tmp; if (q[tp]!=tmp){ anc[tmp]=q[tp]; q[++tp]=tmp; p[++cnt]=tmp; } anc[a[i]]=tmp; q[++tp]=a[i]; } sort(p+1,p+cnt+1,cmp); for (i=1; i<=cnt; i++){ int x=p[i]; len[x]=d[x]-d[anc[x]]; if (bo[x]){ sz[x]=1; f[x]=g[x]=0; } else{ sz[x]=0; f[x]=inf; g[x]=-inf; } sum[x]=0; } ll t1=0; int t2=inf,t3=-inf; for (i=cnt; i>1; i--){ int x=p[i],y=anc[x]; t1+=(sum[x]+(ll)len[x]*sz[x])*sz[y]+sum[y]*sz[x]; sz[y]+=sz[x]; sum[y]+=sum[x]+(ll)len[x]*sz[x]; t2=min(t2,f[y]+f[x]+len[x]); f[y]=min(f[y],f[x]+len[x]); t3=max(t3,g[y]+g[x]+len[x]); g[y]=max(g[y],g[x]+len[x]); } for (i=1; i<=m; i++) bo[a[i]]=0; printf("%lld %d %d\n",t1,t2,t3); } int main(){ n=read(); int i; bin[0]=1; for (i=1; i<=20; i++) bin[i]=bin[i-1]<<1; for (i=1; i<n; i++){ int x=read(),y=read(); add(x,y); add(y,x); } dfs(1); int cas=read(); while (cas--) solve(); return 0; }
by lych
2016.3.7