题目:http://codeforces.com/contest/504/problem/E
快速查询LCP,可以用后缀数组,但树上的字符串不是一个序列;
所以考虑转化成序列—— dfs 序!
普通的 dfs 序中,子树是一段连续的区间,而这里要查询的是链,自然想到树链剖分后的 dfs 序;
这样一条重链在 dfs 序上是一段连续的区间,查询 LCP 时一段一段查询即可,可以用 vector 存下一条路径的所有段;
还要区分方向,所以把 dfs 序得到的字符串再反向复制一遍,为了两串之间不影响,在中间加一个比较小的字符;
预处理ST表可以做到 O(1) 查询两个后缀的 LCP,所以复杂度是预处理 nlogn + 查询 mlogn;
细节比较多...但其实也就是树剖,后缀数组,ST表;
存路径上的段感觉比较麻烦...于是借鉴了一番AC代码,用了 vector,很方便,不过略慢一点。
代码如下:
#include#include #include #include #include #define pb push_back #define mkp make_pair #define pii pair #define fs first #define sc second using namespace std; int const xn=3e5+5,xxn=(xn<<1),xm=1e6+5; int n,hd[xn],ct,to[xn<<1],nxt[xn<<1],top[xn],siz[xn],son[xn],dfn[xn],tim,dep[xn],fa[xn]; int m,mx,tax[xxn],rk[xxn],sa[xxn],tp[xxn],ht[xxn][20],bin[20],r[xxn]; char rs[xn],s[xxn]; vector va,vb; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} void dfs(int x,int ff) { siz[x]=1; fa[x]=ff; dep[x]=dep[ff]+1; for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==ff)continue; dfs(u,x); siz[x]+=siz[u]; if(siz[u]>siz[son[x]])son[x]=u; } } void dfs2(int x) { dfn[x]=++tim; s[tim]=rs[x]; if(son[x])top[son[x]]=top[x],dfs2(son[x]); for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=fa[x]&&u!=son[x])top[u]=u,dfs2(u); } void Rsort() { for(int i=1;i<=m;i++)tax[i]=0; for(int i=1;i<=mx;i++)tax[rk[tp[i]]]++; for(int i=1;i<=m;i++)tax[i]+=tax[i-1]; for(int i=mx;i;i--)sa[tax[rk[tp[i]]]--]=tp[i]; } void work() { m=128; s[n+1]='@';//s[n+1]! for(int i=1;i<=n;i++)s[mx-i+1]=s[i]; for(int i=1;i<=mx;i++)rk[i]=s[i],tp[i]=i; Rsort(); for(int k=1;k<=mx;k<<=1) { int num=0; for(int i=mx-k+1;i<=mx;i++)tp[++num]=i; for(int i=1;i<=mx;i++) if(sa[i]>k)tp[++num]=sa[i]-k; Rsort(); swap(rk,tp); rk[sa[1]]=1; num=1; for(int i=2;i<=mx;i++) rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+k]==tp[sa[i-1]+k])?num:++num;//+k if(num==mx)break; m=num; } } void get() { bin[0]=1; for(int i=1;i<20;i++)bin[i]=(bin[i-1]<<1); r[1]=0; for(int i=2;i<=mx;i++)r[i]=r[i>>1]+1; int k=0; for(int i=1;i<=mx;i++)//1 { if(rk[i]==1)continue; if(k)k--; int j=sa[rk[i]-1]; while(i+k<=mx&&j+k<=mx&&s[i+k]==s[j+k])k++;//<=mx!! ht[rk[i]][0]=k; } for(int j=1;j<=19;j++) for(int i=1;i<=mx&&i+bin[j]-1<=mx;i++)//-1 ht[i][j]=min(ht[i][j-1],ht[i+bin[j-1]][j-1]); } int lca(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]] //top[x],top[y]! x=fa[top[x]]; } return dep[x] x:y; } vector cl() { vector ret,ret2; int x=rd(),y=rd(),L=lca(x,y); while(top[x]!=top[L]) ret.pb(mkp(mx-dfn[x]+1,dfn[x]-dfn[top[x]]+1)),x=fa[top[x]]; ret.pb(mkp(mx-dfn[x]+1,dfn[x]-dfn[L]+1)); while(top[y]!=top[L]) ret2.pb(mkp(dfn[top[y]],dfn[y]-dfn[top[y]]+1)),y=fa[top[y]];//dfn[top[y]]! if(y!=L)ret2.pb(mkp(dfn[L]+1,dfn[y]-dfn[L]));//not include L int siz=ret2.size(); for(int i=siz-1;i>=0;i--)ret.pb(ret2[i]); return ret; } int getlcp(int x,int y) { if(x==y)return mx;//! x=rk[x]; y=rk[y];//! if(x>y)swap(x,y); x++; int w=r[y-x+1]; return min(ht[x][w],ht[y-bin[w]+1][w]);//+1 (h[y][w]->y+bin[w]-1) } int main() { n=rd(); scanf("%s",rs+1); mx=(n<<1)+1; for(int i=1,x,y;i rd(),add(x,y),add(y,x); dfs(1,0); top[1]=1; dfs2(1); work(); get(); int Q=rd(); for(int i=1;i<=Q;i++) { va=cl(); vb=cl(); int ans=0,len=0,p1=0,p2=0,s1=va.size(),s2=vb.size(); while(p1 s2) { len=getlcp(va[p1].fs,vb[p2].fs); len=min(len,min(va[p1].sc,vb[p2].sc)); ans+=len; va[p1].fs+=len; vb[p2].fs+=len; va[p1].sc-=len; vb[p2].sc-=len;//! if(va[p1].sc&&vb[p2].sc)break; if(!va[p1].sc)p1++; if(!vb[p2].sc)p2++; } printf("%d\n",ans); } return 0; }