题目:https://loj.ac/problem/3055
先写了暴力。本来想的是 n<=300 的那个在树上暴力维护好整个字符串, x=1 的那个用主席树维护好字符串和 nxt 数组。但 x=1 的部分会 TLE ,而且似乎不太对的样子。
#include#include #include #include #define ll long long #define pb push_back #define ls Ls[cr] #define rs Rs[cr] using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int n; namespace S1{ const int N=305,M=N*N; int fa[N],q[N],s[M],nxt[M]; ll sm[N]; vector<int> c[N],nt[N]; void add(int x,int y,int m=0,int ch=0) { sm[y]=sm[x]; fa[y]=x; if(!m)return; int top=0, cr=x; while(cr)q[++top]=cr,cr=fa[cr]; int tot=0; for(int i=top;i;i--) { cr=q[i]; for(int j=0,lm=c[cr].size();j ) s[++tot]=c[cr][j], nxt[tot]=nt[cr][j]; } c[y].resize(m); nt[y].resize(m); int i,j; if(!tot){ s[1]=c[y][0]=ch;i=2;j=2;} else { i=tot+1;j=1;} for(;j<=m;j++,i++) { s[i]=ch; cr=nxt[i-1]; while(cr&&s[cr+1]!=ch)cr=nxt[cr]; if(s[cr+1]==ch)nxt[i]=cr+1; else nxt[i]=0; c[y][j-1]=ch; nt[y][j-1]=nxt[i]; sm[y]+=nxt[i]; } } void solve() { int op,x; char ch[5]; for(int i=1;i<=n;i++) { op=rdn();x=rdn(); if(op==1) { scanf("%s",ch); add(i-1,i,x,ch[0]-'a'+1);} else add(x,i); printf("%lld\n",sm[i]); } } } namespace S2{ const int N=1e5+5,M=2e6+5; int rt[N],tot,Ls[M],Rs[M],cd[N]; ll sm[N]; struct Node{ int c,nxt;}a[M]; int ins(int l,int r,int &cr,int pr,int p,int ch) { if(!cr){cr=++tot;ls=Ls[pr];rs=Rs[pr];} if(l==r){a[cr].c=ch;return cr;} int mid=l+r>>1; if(p<=mid)return ins(l,mid,ls,Ls[pr],p,ch); return ins(mid+1,r,rs,Rs[pr],p,ch); } Node qry(int l,int r,int cr,int p) { if(l==r)return a[cr]; int mid=l+r>>1; if(p<=mid)return qry(l,mid,ls,p); return qry(mid+1,r,rs,p); } void add(int cr,int pr,int m,int ch) { sm[cr]=sm[pr]; cd[cr]=cd[pr]; for(int i=1,d;i<=m;i++) { cd[cr]++; d=ins(1,n,rt[cr],rt[pr],cd[cr],ch); int p=qry(1,n,rt[cr],cd[cr]-1).nxt; while(p&&qry(1,n,rt[cr],p+1).c!=ch) p=qry(1,n,rt[cr],p).nxt; if(p+1!=cd[cr]&&qry(1,n,rt[cr],p+1).c==ch)//!= a[d].nxt=p+1; else a[d].nxt=0; sm[cr]+=a[d].nxt; } } void solve() { int op,x; char ch[5]; for(int i=1;i<=n;i++) { op=rdn();x=rdn(); if(op==1) { scanf("%s",ch); add(i,i-1,x,ch[0]-'a'+1);} else {sm[i]=sm[x];rt[i]=rt[x];cd[i]=cd[x];} printf("%lld\n",sm[i]); } } } int main() { n=rdn(); if(n<=300){S1::solve();return 0;} if(n<=1e5){S2::solve();return 0;} return 0; }
然后看了题解。
因为有 “加入的字符和上一个不同” 的限制,所以考虑一段 x 的末尾后面能续上 x 的 nxt 数组,只有自己的 nxt 跳到了另一段 y 的末尾,满足 x 和 y 的字符与长度均相同。
那个 nxt 就是把一段看做一个字符、相同看做两段的字符与长度均相同的 nxt 数组。
一边跳 nxt 一边累计答案,方法是记录一个 lst 表示当前段已经有前 lst 个字符贡献过答案;如果遇到 c[ p+1 ] == c[ cr ] ( c[ ] 表示字符, p 表示跳到的 nxt ),那么能匹配上的是当前段的前 min( len[ p+1 ] , len[ cr ] ) 个字符(len 表示段长);其中之前没贡献过答案的就是本次要贡献答案的,贡献是 ( s[ p ] + lst + 1 ) 到 ( s[ p ] + min( len[ p+1 ] , len[ cr ] ) ) 的等差数列求和。然后把 lst 更新成 min( len[ p+1 ] , len[ cr ] ) 。
如果第一段的字符和自己相同,而第一段的长度比自己小(大于等于自己的话,在跳 nxt 的时候已经用等差数列加过了。所以跳 nxt 的 break 条件放在贡献答案之后),那么还可以给答案贡献 ( len[ cr ] - lst ) 倍的 len[ 1 ] 。(注意是 ( len[ cr ] - lst ) 而不是 ( len[ cr ] - len[ 1 ] ) )并且这种情况的 nxt[ cr ] 应该等于 1 而不是 0 。
把询问离线,在树上用全局变量维护当前的 c[ ] 和 nxt[ ] , dfs 一遍即可。这样复杂度不对,但可过。目前只写了这样。
#include#include #include #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mn(int a,int b){return aa:b;} int Mx(int a,int b){return a>b?a:b;} const int N=1e5+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int n,c[N],len[N],tc[N],tl[N],s[N],nt[N]; int hd[N],xnt,to[N],nxt[N],ans[N]; int cz(int l,int r) { if(l>r)return 0; return (ll)(l+r)*(r-l+1)/2%mod; } void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void dfs(int cr,int pr,int cd) { ans[cr]=pr; if(len[cr]) { cd++; nt[cd]=0;///// tc[cd]=c[cr]; tl[cd]=len[cr]; s[cd]=s[cd-1]+len[cr]; if(cd==1) { ans[cr]=cz(0,len[cr]-1); nt[cd]=0; } else { int p=nt[cd-1],lst=0; while(1) { if(tc[p+1]==c[cr]) { int tp=Mn(tl[p+1],len[cr]); if(tp>lst) { ans[cr]=upt(ans[cr]+cz(s[p]+lst+1,s[p]+tp)); lst=tp;} } if(!p||(tc[p+1]==c[cr]&&tl[p+1]==len[cr]))break; p=nt[p]; } if(tc[p+1]==c[cr]&&tl[p+1]==len[cr]) nt[cd]=p+1; else if(!p&&tc[1]==c[cr]&&tl[1]<len[cr]) ans[cr]=(ans[cr]+(ll)tl[1]*(len[cr]-lst))%mod,nt[cd]=1; //-lst not len[1]//nxt=1 not 0 } } for(int i=hd[cr];i;i=nxt[i]) dfs(to[i],ans[cr],cd); } int main() { n=rdn(); char ch[5]; for(int i=1;i<=n;i++) { int op=rdn(); if(op==2){ int d=rdn();add(d,i);continue;} len[i]=rdn(); scanf("%s",ch); c[i]=ch[0]-'a'+1; add(i-1,i); } dfs(0,0,0); for(int i=1;i<=n;i++)printf("%d\n",ans[i]); return 0; }
然后参考这里的题解(和代码):https://www.cnblogs.com/zhoushuyu/p/10680094.html
复杂度不对的原因是暴力跳 nxt 。可以建 “kmp自动机” ,就是 pr[ i ][ j ] 表示 i 位置后面接 j 字符的话 nxt 会跳到哪个位置。新的位置 i 继承它的 nxt 的 pr[ ][ ] ,i-1 的某个 pr[ ][ ] 值改为 i 。
根据接上来的长度不同,即使字符一样, nxt 仍可能跳到不同的位置。所以每个位置开 26 个主席树维护接上各种字符的不同长度, nxt 会跳到哪个位置。
边跳还要边统计答案。把这个信息也放到主席树上。
答案由两部分构成。设当前段能匹配的长度为 len , 一部分答案是 1 ~ len 的等差数列求和,另一部分是 1 ~ len 对应的 nxt 位置的前缀长度求和。
考虑已经做完当前段,让它给上一个位置的主席树一些更新。设当前段长为 cd , prs 表示到上一个位置为止的前缀段长。
考虑原来的暴力,跳到一个字符相同的位置,可以给当前段的一个前缀的每个位置提供一种可能的贡献,这里需要把 1~cd 位置的 “可能贡献” 改成当前的 prs 。这样一定最优。
所以把主席树上 1~cd 位置的值都改成 prs 。把 cd 位置的 nxt 改成当前段。求答案的时候,假设要匹配的段的长度是 cd2 ,那么它的 nxt 就是主席树 cd2 处记录的 nxt ,它的过程中答案就是主席树 1~cd2 位置的值的和。
注意处理与第一段匹配的情况。需要记录 “当前段最长能匹配多长” 。这个顺便记录即可。就是每次要修改的时候,对应值都可以对当前段长 cd 取 max 。
代码里 rt[ top ][ ch ] 表示 “通过 ch 的边进入 top 之后的种种可能” 。所以往下走的时候,是把 rt[ pr+1 ] 赋值给 rt[ top+1 ] ,用的就是 “通过当前字符从 pr 进入 pr+1 ” 的信息。(pr 表示当前位置的 nxt )
注意主席树新开节点的时候把原来的信息搬过来。
注意在外面枚举 0 点的出边,进入的时候把 rt[ 0 ][ ] 之类的改成初值。
#include#include #include #include #define ll long long #define ls Ls[cr] #define rs Rs[cr] using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return aa:b;} const int N=1e5+5,M=5e6+5,K=30,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int n,m,cnt,hd[N],xnt,to[N],nxt[N],w[N],c[N],ps[N],ans[N]; int rt[N][K],mxl[N][K],prs[N],top,tc,tl; int tot,Ls[M],Rs[M],sm[M],nt[M],tg[M],tim,dfn[M]; void add(int x,int y,int cd,int ch) {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=cd;c[xnt]=ch;} int cal(int x){return (ll)(1+x)*x/2%mod;} int nwnd(int cr) { if(dfn[cr]==tim)return cr; tot++; dfn[tot]=tim; Ls[tot]=ls; Rs[tot]=rs; sm[tot]=sm[cr]; nt[tot]=nt[cr]; tg[tot]=tg[cr];return tot; } void cz(int &cr,int len,int k){ cr=nwnd(cr); sm[cr]=(ll)len*k%mod;} void pshd(int cr,int l,int mid,int r) { if(!tg[cr])return; int k=tg[cr]; tg[cr]=0; cz(ls,mid-l+1,k); cz(rs,r-mid,k); tg[ls]=tg[rs]=k; } void mdfy(int l,int r,int &cr,int R,int k,int p) { if(!cr||dfn[cr]!=tim)cr=nwnd(cr);// if(r 1,k);tg[cr]=k;return;} if(l==r){cz(cr,r-l+1,k);nt[cr]=p;return;} int mid=l+r>>1; pshd(cr,l,mid,r); mdfy(l,mid,ls,R,k,p); if(mid 1,r,rs,R,k,p); sm[cr]=upt(sm[ls]+sm[rs]); } int qry(int l,int r,int cr,int R,int &p) { if(!cr)return 0; if(r return sm[cr]; if(l==r){p=nt[cr]; return sm[cr];} int mid=l+r>>1; pshd(cr,l,mid,r); int ret=qry(l,mid,ls,R,p); if(mid 1,r,rs,R,p)); return ret; } void dfs(int cr,int cd,int ch) { prs[++top]=prs[top-1]+cd; int pr=0; if(top==1)ans[cr]=upt(ans[cr]+cal(cd-1)),tc=ch,tl=cd; else { ans[cr]=upt(ans[cr]+qry(1,m,rt[top][ch],cd,pr)); ans[cr]=upt(ans[cr]+cal(Mn(cd,mxl[top][ch])));//Mn!! if(!pr&&tc==ch&&tl<cd) { if(cd>mxl[top][ch]) ans[cr]=(ans[cr]+(ll)tl*(cd-mxl[top][ch]))%mod; pr=1;/////// } } mxl[top][ch]=Mx(mxl[top][ch],cd); tim++; mdfy(1,m,rt[top][ch],cd,prs[top-1],top); for(int i=hd[cr];i;i=nxt[i]) { memcpy(rt[top+1],rt[pr+1],sizeof rt[pr+1]);//pr+1!!! memcpy(mxl[top+1],mxl[pr+1],sizeof mxl[pr+1]); ans[to[i]]=ans[cr]; dfs(to[i],w[i],c[i]); } top--; } int main() { n=rdn(); char ch; for(int i=1;i<=n;i++) { int op=rdn(), x=rdn(); if(op==1) { cin>>ch; ps[i]=++cnt; m=Mx(m,x); add(ps[i-1],ps[i],x,ch-'a'+1); } else ps[i]=ps[x]; } for(int i=hd[0];i;i=nxt[i]) { memset(rt[1],0,sizeof rt[1]);///// memset(mxl[1],0,sizeof mxl[1]); dfs(to[i],w[i],c[i]); } for(int i=1;i<=n;i++)printf("%d\n",ans[ps[i]]); return 0; }