题目大意:给定一棵树,m次询问,每次给出k个关键点,询问这k个点之间的两两距离和、最小距离和最大距离
n<=100W,m<=50000,Σk<=2*n
处理方法同2286 消耗战 地址见 http://blog.csdn.net/popoqqq/article/details/42493725
这个题的DP有些麻烦 因此我把要处理的节点单独拎出来做的DP 具体状态和转移见代码
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define M 1001001 #define INF 0x3f3f3f3f using namespace std; struct abcd{ int to,next; }table[M<<1]; int head[M],tot; int n,m; int pos[M],dpt[M],fa[M][20]; long long ans,ans_min,ans_max; bool is_key_point[M]; void Add(int x,int y) { table[++tot].to=y; table[tot].next=head[x]; head[x]=tot; } void DFS(int x) { static int cnt=0; int i; pos[x]=++cnt;dpt[x]=dpt[fa[x][0]]+1; for(i=head[x];i;i=table[i].next) if(table[i].to!=fa[x][0]) { fa[table[i].to][0]=x; DFS(table[i].to); } } int LCA(int x,int y) { int j; if(dpt[x]<dpt[y]) swap(x,y); for(j=19;~j;j--) if(dpt[fa[x][j]]>=dpt[y]) x=fa[x][j]; if(x==y) return x; for(j=19;~j;j--) if(fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j]; return fa[x][0]; } bool Compare(int x,int y) { return pos[x] < pos[y] ; } void Tree_DP(int x,int from) { static long long f[M],g[M],max_dis[M],min_dis[M]; //f[x]表示以x为根的子树中有多少关键点 //g[x]表示以x为根的子树中所有关键点到x的距离之和 //max_dis[x]/min_dis[x]表示节点x为根的子树中的关键点到x的距离的最大/最小值 int i; f[x]=is_key_point[x];g[x]=0; max_dis[x]=(is_key_point[x]?0:-INF); min_dis[x]=(is_key_point[x]?0:INF); for(i=head[x];i;i=table[i].next) { if(table[i].to==from) continue; Tree_DP(table[i].to,x); int dis=dpt[table[i].to]-dpt[x]; ans+=(g[x]+f[x]*dis)*f[table[i].to]+g[table[i].to]*f[x]; ans_min=min(ans_min,min_dis[x]+min_dis[table[i].to]+dis); ans_max=max(ans_max,max_dis[x]+max_dis[table[i].to]+dis); f[x]+=f[table[i].to]; g[x]+=g[table[i].to]+f[table[i].to]*dis; max_dis[x]=max(max_dis[x],max_dis[table[i].to]+dis); min_dis[x]=min(min_dis[x],min_dis[table[i].to]+dis); } } int main() { int i,j,k,x,y; cin>>n; for(i=1;i<n;i++) { scanf("%d%d",&x,&y); Add(x,y),Add(y,x); } DFS(1); for(j=1;j<=19;j++) for(i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; cin>>m; for(i=1;i<=m;i++) { static int a[M]; static int stack[M],top; scanf("%d",&k); for(j=1;j<=k;j++) scanf("%d",&a[j]); sort(a+1,a+k+1,Compare); tot=0; stack[top=1]=1; head[1]=0;is_key_point[1]=(a[1]==1); for(j=1;j<=k;j++) { int lca=LCA(a[j],stack[top]); while(dpt[lca]<dpt[stack[top]]) { if(dpt[stack[top-1]]<=dpt[lca]) { int temp=stack[top--]; if(stack[top]!=lca) { stack[++top]=lca; head[lca]=0; is_key_point[lca]=0; } Add(lca,temp); break; } Add(stack[top-1],stack[top]); stack[top--]=0; } if(stack[top]!=a[j]) { stack[++top]=a[j]; head[a[j]]=0; } is_key_point[a[j]]=1; } while(top>1) Add(stack[top-1],stack[top]),top--; ans=0;ans_min=INF;ans_max=-INF; Tree_DP(1,0); #ifdef ONLINE_JUDGE printf("%lld %lld %lld\n",ans,ans_min,ans_max); #else printf("%I64d %I64d %I64d\n",ans,ans_min,ans_max); #endif } return 0; }