[BZOJ3910]火车(lca+树链剖分)

题目描述

传送门

题解

lca+树链剖分裸题

代码

#include
#include
#include
#include
#include
using namespace std;
#define N 500005
#define sz 19

int n,m,x,y,r,now,dfs_clock,a[N];
int tot,point[N],nxt[N*2],v[N*2];
int h[N],f[N][sz+5],father[N],size[N],son[N],top[N],num[N];
bool flag[N*4],delta[N*4];
long long ans;

void add(int x,int y)
{
    ++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
}
void build(int x,int fa)
{
    h[x]=h[fa]+1;father[x]=fa;size[x]=1;
    for (int i=1;i1]][i-1];
    for (int i=point[x];i;i=nxt[i])
        if (v[i]!=fa)
        {
            f[v[i]][0]=x;
            build(v[i],x);
            size[x]+=size[v[i]];
            if (size[v[i]]>size[son[x]]) son[x]=v[i];
        }
}
void dfs(int x,int fa)
{
    if (x==son[fa]) top[x]=top[fa];
    else top[x]=x;
    num[x]=++dfs_clock;
    if (son[x]) dfs(son[x],x);
    for (int i=point[x];i;i=nxt[i])
        if (v[i]!=fa&&v[i]!=son[x])
            dfs(v[i],x);
}
int lca(int x,int y)
{
    if (h[x]int k=h[x]-h[y];
    for (int i=0;iif ((k>>i)&1) x=f[x][i];
    if (x==y) return x;
    for (int i=sz-1;i>=0;--i)
        if (f[x][i]!=f[y][i])
            x=f[x][i],y=f[y][i];
    return f[x][0];
}
void pushdown(int now,int l,int r,int mid)
{
    if (delta[now])
    {
        flag[now<<1]=delta[now<<1]=1;
        flag[now<<1|1]=delta[now<<1|1]=1;
        delta[now]=0;
    }
}
void change(int now,int l,int r,int lr,int rr)
{
    int mid=(l+r)>>1;
    if (lr<=l&&r<=rr)
    {
        flag[now]=1;
        delta[now]=1;
        return;
    }
    pushdown(now,l,r,mid);
    if (lr<=mid) change(now<<1,l,mid,lr,rr);
    if (mid+1<=rr) change(now<<1|1,mid+1,r,lr,rr);
}
int query(int now,int l,int r,int x)
{
    int mid=(l+r)>>1;
    if (l==r) return flag[now];
    pushdown(now,l,r,mid);
    if (x<=mid) return query(now<<1,l,mid,x);
    else return query(now<<1|1,mid+1,r,x);
}
void CHANGE(int u,int t)
{
    int f1=top[u],f2=top[t];
    while (f1!=f2)
    {
        if (h[f1]1,1,n,num[f1],num[u]);
        u=father[f1];
        f1=top[u];
    }
    if (num[u]>num[t]) swap(u,t);
    change(1,1,n,num[u],num[t]);
}
int main()
{
    scanf("%d%d%d",&n,&m,&now);
    for (int i=1;i"%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    for (int i=1;i<=m;++i) scanf("%d",&a[i]);
    build(1,0);
    dfs(1,0);
    CHANGE(now,now);
    for (int i=1;i<=m;++i)
    {
        x=i;
        while (x<=m&&query(1,1,n,num[a[x]])) ++x;
        if (x>m) break;
        r=lca(now,a[x]);
        ans+=(long long)h[now]-h[r]+h[a[x]]-h[r];
        CHANGE(now,a[x]);
        now=a[x];i=x;
    }
    printf("%lld\n",ans);
}

你可能感兴趣的:(题解,lca,树链剖分)