Codeforces 504E Misha and LCP on Tree 树链剖分+后缀数组

题意

给一棵树,每个节点上有一个字符。每次询问a到b的路径组成的字符串和c到d的路径组成的字符串的lcp。
n<=300000,q<=1000000

分析

我们可以先重链剖分一下,然后把每条重链正着放一遍反着放一遍,这样就形成了一个字符串,建出这个字符串的后缀数组。
询问的时候,把路径上的O(log)个区间拿出来,然后扔到后缀数组里面求lcp就好了。
时间复杂度 O(nlogn+qlogn) O ( n l o g n + q l o g n )

代码

#include
#include
#include
#include
#include
#include
using namespace std;

const int N=300005;

int n,s[N*2],dep[N],fa[N],size[N],p1[N],p2[N],cnt,last[N],b[N*2],c[N*2],d[N*2],rank[N*4],sa[N*2],rmq[N*2][22],bin[22],top[N],tim,lg[N*2];
char str[N];
struct edge{int to,next;}e[N*2];
struct data{int l,r,len;}inv1[N],inv2[N];

int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

void addedge(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
    e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}

void dfs1(int x)
{
    dep[x]=dep[fa[x]]+1;size[x]=1;
    for (int i=last[x];i;i=e[i].next)
    {
        if (e[i].to==fa[x]) continue;
        fa[e[i].to]=x;
        dfs1(e[i].to);
        size[x]+=size[e[i].to];
    }
}

void dfs2(int x,int chain)
{
    p1[x]=++tim;top[x]=chain;int k=0;
    for (int i=last[x];i;i=e[i].next)
        if (e[i].to!=fa[x]&&size[e[i].to]>size[k]) k=e[i].to;
    if (!k) return;
    dfs2(k,chain);
    for (int i=last[x];i;i=e[i].next)
        if (e[i].to!=fa[x]&&e[i].to!=k) dfs2(e[i].to,e[i].to);
}

void get_sa(int n,int m)
{
    for (int i=1;i<=n;i++) b[s[i]]++;
    for (int i=1;i<=m;i++) b[i]+=b[i-1];
    for (int i=n;i>=1;i--) c[b[s[i]]--]=i;
    int t=0,j=1;
    for (int i=1;i<=n;i++)
    {
        if (s[c[i]]!=s[c[i-1]]) t++;
        rank[c[i]]=t;
    }
    while (j<=n)
    {
        for (int i=1;i<=n;i++) b[i]=0;
        for (int i=1;i<=n;i++) b[rank[i+j]]++;
        for (int i=1;i<=n;i++) b[i]+=b[i-1];
        for (int i=n;i>=1;i--) c[b[rank[i+j]]--]=i;
        for (int i=1;i<=n;i++) b[i]=0;
        for (int i=1;i<=n;i++) b[rank[i]]++;
        for (int i=1;i<=n;i++) b[i]+=b[i-1];
        for (int i=n;i>=1;i--) d[b[rank[c[i]]]--]=c[i];
        t=0;
        for(int i=1;i<=n;i++)
        {
            if (rank[d[i]]!=rank[d[i-1]]||rank[d[i]]==rank[d[i-1]]&&rank[d[i]+j]!=rank[d[i-1]+j]) t++;
            c[d[i]]=t;
        }
        for (int i=1;i<=n;i++) rank[i]=c[i];
        if (t==n) break;
        j<<=1;
    }
    for (int i=1;i<=n;i++) sa[rank[i]]=i;
}

void get_height(int n)
{
    int k=0;
    for (int i=1;i<=n;i++)
    {
        if (k) k--;
        int j=sa[rank[i]-1];
        while (i+k<=n&&j+k<=n&&s[i+k]==s[j+k]) k++;
        rmq[rank[i]][0]=k;
    }
}

void get_rmq(int n)
{
    for (int i=1;i<=n;i++) lg[i]=log(i)/log(2);
    bin[0]=1;
    for (int i=1;i<=lg[n];i++) bin[i]=bin[i-1]*2;
    for (int j=1;j<=lg[n];j++)
        for (int i=1;i+bin[j]-1<=n;i++)
            rmq[i][j]=min(rmq[i][j-1],rmq[i+bin[j-1]][j-1]);
}

int get_mn(int l,int r)
{
    if (l==r) return n*2-l+1;
    l=rank[l];r=rank[r];
    if (l>r) swap(l,r);
    l++;int w=lg[r-l+1];
    return min(rmq[l][w],rmq[r-bin[w]+1][w]);
}

int get_lca(int x,int y)
{
    while (top[x]!=top[y])
    {
        if (dep[top[x]]y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]y]?x:y;
}

void get(data *a,int &tot,int x,int y)
{
    int lca=get_lca(x,y);
    while (dep[top[x]]>=dep[lca]) a[++tot]=(data){p2[x],p2[top[x]],0},x=fa[top[x]];
    if (dep[x]>=dep[lca]) a[++tot]=(data){p2[x],p2[lca],0};
    int tmp=tot;
    while (dep[top[y]]>dep[lca]) a[++tot]=(data){p1[top[y]],p1[y],0},y=fa[top[y]];
    if (dep[y]>dep[lca]) a[++tot]=(data){p1[lca]+1,p1[y],0};
    reverse(a+tmp+1,a+tot+1);
}

int main()
{
    n=read();scanf("%s",str+1);
    for (int i=1;iint x=read(),y=read();
        addedge(x,y);
    }
    dfs1(1);dfs2(1,1);
    for (int i=1;i<=n;i++) p2[i]=n*2-p1[i]+1,s[p1[i]]=s[p2[i]]=str[i]-'a'+1;
    get_sa(n*2,30);get_height(n*2);get_rmq(n*2);
    int q=read();
    while (q--)
    {
        int a=read(),b=read(),c=read(),d=read(),tot1=0,tot2=0;
        get(inv1,tot1,a,b);get(inv2,tot2,c,d);
        for (int i=1;i<=tot1;i++) inv1[i].len=inv1[i].r-inv1[i].l+1;
        for (int i=1;i<=tot2;i++) inv2[i].len=inv2[i].r-inv2[i].l+1;
        int ans=0,p1=1,p2=1,l1=1,l2=1;
        while (p1<=tot1&&p2<=tot2)
        {
            int len=min(get_mn(inv1[p1].l+l1-1,inv2[p2].l+l2-1),min(inv1[p1].len-l1+1,inv2[p2].len-l2+1));
            ans+=len;
            if (len1,inv2[p2].len-l2+1)) break;
            l1+=len;l2+=len;
            if (l1>inv1[p1].len) p1++,l1=1;
            if (l2>inv2[p2].len) p2++,l2=1;
        }
        printf("%d\n",ans);
    }
    return 0;
}

你可能感兴趣的:(树链剖分,后缀数组)