[BZOJ3572][Hnoi2014]世界树(虚树+树形dp+二分+lca)

题目描述

传送门

题解

首先建出虚树来,边权即为原树上的距离
这题我dp的方法非常蠢
f(i)表示从i的父边出去(必须经过i的父亲)到达的关键点的最短路
fp(i)表示最短路的点
g(i)表示i到i的子树中到达的关键点的最短路
gp(i)表示最短路的点
然后这两个互相转移一下…

dp完了之后枚举虚树上的每一条边(u,v)
因为已经知道了从u出去到关键点的最短路和从v出去到关键点的最短路
然后就可以计算出这条边上的哪些点归到u的关键点去,哪些归到v的关键点去
嫌麻烦写了二分…
不过恶心的是统计答案
因为建出来的虚树只包含了关键点和其lca,在每一个点和每一条边上都有可能有若干没有关键点的子树
维护一个cnt表示虚树上每一个点所连的非关键点有多少个(包括子树)
然后根据原树的size和维护出的cnt乱搞…

调了两节课…午饭都没吃…恶心死我了
可能根本原因是我太蠢了

代码

#include
#include
#include
#include
#include
using namespace std;
#define N 600005
#define sz 19

int n,q,m,dfs_clock,inf;
int tot,point[N],nxt[N],v[N],last[N];
int in[N],out[N],h[N],size[N],cnt[N],stack[N],top,st[N][sz+3],key[N],flag[N];
int qu[N],f[N],fp[N],g[N],gp[N],Son[N],pt[N],goal[N],ans[N];

void add(int x,int y)
{
    ++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
}
void build(int x,int fa)
{
    in[x]=++dfs_clock;h[x]=h[fa]+1;size[x]=1;
    for (int i=1;i1]][i-1];
    for (int i=point[x];i;i=nxt[i])
        if (v[i]!=fa)
        {
            st[v[i]][0]=x;
            build(v[i],x);
            size[x]+=size[v[i]];
        }
    out[x]=dfs_clock;
}
int cmp(int a,int b)
{
    return in[a]<in[b];
}
int lca(int x,int y)
{
    if (h[x]int k=h[x]-h[y];
    for (int i=0;iif ((k>>i)&1) x=st[x][i];
    if (x==y) return x;
    for (int i=sz-1;i>=0;--i)
        if (st[x][i]!=st[y][i]) x=st[x][i],y=st[y][i];
    return st[x][0];
}
void sontofa(int x)
{
    if (key[x]==q) g[x]=0,gp[x]=x;
    for (int i=point[x];i;i=nxt[i])
    {
        sontofa(v[i]);
        if (g[v[i]]+last[v[i]]else if (g[v[i]]+last[v[i]]==g[x]&&gp[v[i]]int find(int x,int k)
{
    for (int i=0;iif ((k>>i)&1) x=st[x][i];
    return x;
}
void calc(int len,int fa,int son)
{
    int minson=g[son],minfa=f[son]-len;
    if (minson&&minfa&&(len+minfaint p=find(son,len-1);
        goal[fa]=goal[son]=fp[son];
        ans[fp[son]]+=size[p]-size[son];
        return;
    }
    if (minson&&minfa&&(minson+lenint p=find(son,len-1);
        goal[fa]=goal[son]=gp[son];
        ans[gp[son]]+=size[p]-size[son];
        return;
    }
    int l=-1,r=len,mid,final=-1;
    while (l<=r)
    {
        if (l==r&&l==-1)
        {
            final=-1;
            break;
        }
        mid=(l+r)>>1;
        if (mid+minson1;
        else r=mid-1;
    }
    if (final==len)
    {
        int p=find(son,len-1);
        goal[fa]=goal[son]=gp[son];
        ans[gp[son]]+=size[p]-size[son];
        return;
    }
    if (final==-1)
    {
        int p=find(son,len-1);
        goal[fa]=goal[son]=fp[son];
        ans[fp[son]]+=size[p]-size[son];
        return;
    }
    int p=find(son,len-1),t=find(son,final);
    goal[fa]=fp[son];goal[son]=gp[son];
    ans[fp[son]]+=size[p]-size[t];
    ans[gp[son]]+=size[t]-size[son];
    return;
}
void fatoson(int x)
{
    if (key[x]==q)
    {
        for (int i=point[x];i;i=nxt[i])
            f[v[i]]=last[v[i]],fp[v[i]]=x;
    }
    else
    {
        Son[0]=0;
        for (int i=point[x];i;i=nxt[i]) Son[++Son[0]]=v[i];

        int Min=inf,Minp=0;
        for (int i=1;i<=Son[0];++i)
        {
            if (f[x]+last[Son[i]]else if (f[x]+last[Son[i]]==f[Son[i]]&&fp[x]if (Min+last[Son[i]]else if (Min+last[Son[i]]==f[Son[i]]&&Minpif (g[Son[i]]+last[Son[i]]else if (g[Son[i]]+last[Son[i]]==Min&&gp[Son[i]]0;
        for (int i=Son[0];i>=1;--i)
        {
            if (Min+last[Son[i]]else if (Min+last[Son[i]]==f[Son[i]]&&Minpif (g[Son[i]]+last[Son[i]]else if (g[Son[i]]+last[Son[i]]==Min&&gp[Son[i]]for (int i=point[x];i;i=nxt[i])
    {
        fatoson(v[i]);
        calc(last[v[i]],x,v[i]);
    }
}
void solve()
{
    sontofa(1);
    fatoson(1);
    for (int i=1;i<=pt[0];++i)
        ans[goal[pt[i]]]+=cnt[pt[i]];
    for (int i=1;i<=m;++i)
        printf("%d ",ans[qu[i]]);puts("");
}
int main()
{
    scanf("%d",&n);
    for (int i=1;iint x,y;scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    build(1,0);
    scanf("%d",&q);
    memset(point,0,sizeof(point));
    memset(g,127,sizeof(g));inf=g[0];
    memset(f,127,sizeof(f));
    while (q)
    {
        scanf("%d",&m);
        for (int i=1;i<=m;++i)
        {
            scanf("%d",&pt[i]);
            qu[i]=pt[i];
            key[pt[i]]=flag[pt[i]]=q;
            goal[pt[i]]=pt[i];
        }
        sort(pt+1,pt+m+1,cmp);pt[0]=m;
        for (int i=2;i<=m;++i)
        {
            int r=lca(pt[i-1],pt[i]);
            if (flag[r]!=q)
                flag[r]=q,pt[++pt[0]]=r;
        }
        if (flag[1]!=q) flag[1]=q,pt[++pt[0]]=1;
        sort(pt+1,pt+pt[0]+1,cmp);
        tot=0;stack[top=1]=1;cnt[1]=size[1];
        for (int i=2;i<=pt[0];++i)
        {
            while (in[pt[i]]<in[stack[top]]||in[pt[i]]>out[stack[top]])
                --top;
            add(stack[top],pt[i]);
            last[pt[i]]=h[pt[i]]-h[stack[top]];
            int r=find(pt[i],h[pt[i]]-h[stack[top]]-1);
            cnt[stack[top]]-=size[r];
            stack[++top]=pt[i];cnt[pt[i]]=size[pt[i]];
        }
        solve();
        for (int i=1;i<=pt[0];++i)
        {
            int x=pt[i];
            point[x]=fp[x]=gp[x]=cnt[x]=ans[x]=last[x]=goal[x]=0;
            f[x]=g[x]=inf;
        }
        --q;
    }
}

你可能感兴趣的:(题解,dp,lca,省选,二分,虚树)