[BZOJ3611][Heoi2014]大工程(虚树+树形dp)

题目描述

传送门

题解

令size(i)表示i子树里有多少个关键点
令sum(i)表示i子树中所有关键点到i的距离和
令Max(i)表示i子树中所有关键点到它的最长链,_Max(i)次长链,Min(i)最短链,_Min(i)次短链
这些都非常好维护,第二问和第三问也很好计算,用最和次拼一下就行了
对于第一问的话,在dp的时候维护一下当前size和sum的乘积就行了
将所有的关键点和它们的lca建出一棵虚树,边权为两点之间的距离
然后按照上面的dp就行了
dp的时候要格外注意子树的根是否是关键点以及儿子的个数

代码

#include
#include
#include
#include
#include
using namespace std;
#define LL long long
#define N 1000005
#define sz 20

int n,q,k,dfs_clock,top;
int tot,point[N],nxt[N*2],v[N*2],c[N*2];
int pt[N],key[N],flag[N],stack[N],h[N],in[N],out[N],f[N][sz+3],size[N];
LL sum[N],Max[N],_Max[N],Min[N],_Min[N],ans1,ans2,ans3;

const LL inf=1e18;

void add(int x,int y,int z)
{
    ++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
}
void build(int x,int fa)
{
    h[x]=h[fa]+1;in[x]=++dfs_clock;
    for (int i=1;i1]][i-1];
    for (int i=point[x];i;i=nxt[i])
        if (v[i]!=fa)
        {
            f[v[i]][0]=x;
            build(v[i],x);
        }
    out[x]=++dfs_clock;
}
int cmp(int a,int b)
{
    return in[a]int lca(int x,int y)
{
    if (h[x]int cha=h[x]-h[y];
    for (int i=0;iif ((cha>>i)&1) x=f[x][i];
    if (x==y) return x;
    for (int i=sz-1;i>=0;--i)
        if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
void treedp(int x)
{
    size[x]=0;
    sum[x]=0;
    Max[x]=_Max[x]=0;
    Min[x]=_Min[x]=inf;
    if (key[x]==q) size[x]=1,Min[x]=0;
    int cnt=0;
    for (int i=point[x];i;i=nxt[i])
    {
        ++cnt;
        treedp(v[i]);
        ans1+=(LL)size[v[i]]*sum[x]+(LL)size[x]*(sum[v[i]]+(LL)c[i]*(LL)size[v[i]]);
        size[x]+=size[v[i]];
        sum[x]+=sum[v[i]]+(LL)c[i]*(LL)size[v[i]];
        if (Max[v[i]]+(LL)c[i]>Max[x])
        {
            _Max[x]=Max[x];
            Max[x]=Max[v[i]]+(LL)c[i];
        }
        else _Max[x]=max(_Max[x],Max[v[i]]+(LL)c[i]);
        if (Min[v[i]]+(LL)c[i]else _Min[x]=min(_Min[x],Min[v[i]]+(LL)c[i]);
    }
    if (key[x]==q||cnt>1)
    {
        ans2=min(ans2,Min[x]+_Min[x]);
        ans3=max(ans3,Max[x]+_Max[x]);
    }
    point[x]=0;
}
int main()
{
    scanf("%d",&n);
    for (int i=1;iint x,y;scanf("%d%d",&x,&y);
        add(x,y,1),add(y,x,1);
    }
    build(1,0);
    memset(point,0,sizeof(point));
    scanf("%d",&q);
    while (q)
    {
        scanf("%d",&k);
        for (int i=1;i<=k;++i)
        {
            scanf("%d",&pt[i]);
            key[pt[i]]=flag[pt[i]]=q;
        }
        sort(pt+1,pt+k+1,cmp);pt[0]=k;
        for (int i=2;i<=k;++i)
        {
            int r=lca(pt[i-1],pt[i]);
            if (flag[r]!=q)
            {
                flag[r]=q;
                pt[++pt[0]]=r;
            }
        }
        if (flag[1]!=q) flag[1]=q,pt[++pt[0]]=1;
        sort(pt+1,pt+pt[0]+1,cmp);
        tot=0;stack[top=1]=1;
        for (int i=2;i<=pt[0];++i)
        {
            while (in[pt[i]]stack[top]]||in[pt[i]]>out[stack[top]])
                --top;
            add(stack[top],pt[i],h[pt[i]]-h[stack[top]]);
            stack[++top]=pt[i];
        }
        ans1=0;ans2=inf;ans3=0;
        treedp(1);
        printf("%lld %lld %lld\n",ans1,ans2,ans3);
        --q;
    }
}

你可能感兴趣的:(题解,dp,省选,虚树)