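给定一个长度为 n 的字符串 s ,你需要计算 s 的所有子串两两之间(有序对,包括子串与其自身配对)的最长公共前缀长度之和,即 $\sum_{1\le l_1\le r_1\le n}\ \sum_{1\le l_2\le r_2\le n}\operatorname{LCP}\big(s[l_1..r_1],\,s[l_2..r_2]\big)$ ,答案对 998244353 取模。(原文此处公式缺失,以上根据下方代码的计算过程还原。)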
1≤n≤500000
对 s 构造后缀数组并求出 height 数组,然后从大到小枚举 LCP 的长度 k ,每次将答案加上 LCP 长度大于等于 k 的子串对的个数。
一开始后缀数组的每一个位置都是孤立的一个块,随着 LCP 长度的减少我们开始合并相邻的一些块。显然我们只需要记录一个块内的后缀长度和以及后缀个数就可以很方便地计算答案,这个信息也是可以合并的。
时间复杂度 O(nlogn) 。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
// NOTE(review): with this using-directive in effect, the global names `rank`
// and `size` declared below can collide with `std::rank` (<type_traits>) and
// `std::size` (C++17) when the file is built with a C++11/C++17 toolchain.
using namespace std;
// Modulus applied to every accumulated quantity (sum, cur1..cur3, ans).
const int P=998244353;
// Capacity of every global array; slightly above the stated limit n <= 500000.
const int N=500050;
// SA[i]     : start index of the i-th smallest suffix (filled by DA).
// rk        : not referenced anywhere in the visible code.
// Ws / Wv   : counting-sort buckets / per-item keys used inside DA.
// x / y     : current equivalence classes / scratch order used by the doubling rounds in DA.
// height[i] : LCP of the suffixes at SA positions i-1 and i (filled by getheight).
// fa        : union-find parent over SA positions (getfather / merge / calc).
// size[r]   : number of activated suffixes in the block whose root is r.
// sum[r]    : sum of the lengths of those suffixes, modulo P.
// pos       : SA positions sorted by height (comp1); p: SA positions sorted by suffix length (comp2).
// rank      : inverse of SA after DA; calc() zeroes it and merge() reuses it as the union-by-rank value.
int SA[N],rk[N],Ws[N],Wv[N],x[N],y[N],height[N],fa[N],size[N],sum[N],pos[N],rank[N],p[N];
// Input string, 0-indexed, NUL-terminated.
char s[N];
// n    : length of s.
// ans  : running answer modulo P.
// cur1 : sum over blocks of sum[r]^2, cur2: sum of sum[r]*size[r], cur3: sum of size[r]^2 (all mod P).
int n,ans,cur1,cur2,cur3;
// Returns x*x reduced modulo P. The product is formed in 64-bit arithmetic so
// it cannot overflow before the reduction.
int sqr(int x)
{
    const long long wide=x;
    return static_cast<int>(wide*wide%P);
}
// Ordering predicate for SA positions: x comes before y when height[x] is
// strictly smaller than height[y].
bool comp1(int x,int y)
{
    const int hx=height[x];
    const int hy=height[y];
    return hx<hy;
}
// Ordering predicate for SA positions: x comes before y when the suffix at SA
// position x is strictly shorter than the suffix at SA position y.
bool comp2(int x,int y)
{
    const int lenX=n-SA[x];
    const int lenY=n-SA[y];
    return lenX<lenY;
}
// Returns true when positions x and y must receive DIFFERENT classes in a
// doubling round of width l: either second half runs past the end of the
// string, or the first-half classes differ, or the second-half classes differ.
// `a` holds the classes of the previous round.
bool cmp(int *a,int x,int y,int l)
{
    // The bounds test comes first so a[x+l] / a[y+l] are never read out of range.
    if (x+l>=n||y+l>=n) return true;
    if (a[x]!=a[y]) return true;
    return a[x+l]!=a[y+l];
}
void DA()
{
int i,j,p,l,mx;
for (i=0,mx=0;i<n;++i) mx=max(mx,Wv[i]=s[i]-'a');
for (i=0;i<=mx;++i) Ws[i]=0;
for (i=0;i<n;++i) ++Ws[Wv[i]];
for (i=1;i<=mx;++i) Ws[i]+=Ws[i-1];
for (i=n-1;i>=0;--i) SA[--Ws[Wv[i]]]=i;
for (x[SA[0]]=p=0,i=1;i<n;++i) x[SA[i]]=p+=s[SA[i]]!=s[SA[i-1]];
for (l=1;l<=n&&p!=n-1;l<<=1)
{
for (p=0,i=n-l;i<n;++i) y[p++]=i;
for (i=0;i<n;++i) if (SA[i]>=l) y[p++]=SA[i]-l;
for (mx=0,i=0;i<n;++i) mx=max(mx,Wv[i]=x[y[i]]);
for (i=0;i<=mx;++i) Ws[i]=0;
for (i=0;i<n;++i) ++Ws[Wv[i]];
for (i=1;i<=mx;++i) Ws[i]+=Ws[i-1];
for (i=n-1;i>=0;--i) SA[--Ws[Wv[i]]]=y[i];
for (i=0;i<n;++i) y[i]=x[i],x[i]=0;
for (x[SA[0]]=p=0,i=1;i<n;++i) x[SA[i]]=p+=cmp(y,SA[i],SA[i-1],l);
}
for (i=0;i<n;++i) rank[SA[i]]=i;
}
void getheight()
{
for (int i=0,h=0;i<n;++i,h=max(0,h-1))
{
if (rank[i]) for (int k=SA[rank[i]-1];i+h<n&&k+h<n&&s[i+h]==s[k+h];++h);
height[rank[i]]=h;
}
}
// Union-find lookup with full path compression: returns the root of the set
// containing `son` and re-points every node on the walked path at that root.
int getfather(int son)
{
    // First pass: locate the root.
    int root=son;
    while (fa[root]!=root) root=fa[root];

    // Second pass: compress the path.
    while (son!=root)
    {
        const int up=fa[son];
        fa[son]=root;
        son=up;
    }
    return root;
}
void merge(int x,int y)
{
x=getfather(x),y=getfather(y);
if (rank[x]<rank[y]) swap(x,y);
fa[y]=x,rank[x]+=rank[x]==rank[y];
cur1=(cur1-sqr(sum[x])+P)%P,cur2=(cur2-1ll*sum[x]*size[x]%P+P)%P,cur3=(cur3-sqr(size[x])+P)%P;
cur1=(cur1-sqr(sum[y])+P)%P,cur2=(cur2-1ll*sum[y]*size[y]%P+P)%P,cur3=(cur3-sqr(size[y])+P)%P;
(sum[x]+=sum[y])%=P,size[x]+=size[y];
(cur1+=sqr(sum[x]))%=P,(cur2+=1ll*sum[x]*size[x]%P)%=P,(cur3+=sqr(size[x]))%=P;
}
void calc()
{
for (int i=0;i<n;++i) fa[i]=i,rank[i]=0,size[i]=0,sum[i]=0,pos[i]=i,p[i]=i;
sort(pos,pos+n,comp1),sort(p,p+n,comp2);
cur1=cur2=cur3=0;
for (int i=0;i<n;++i) (cur1+=sqr(sum[i]))%=P,(cur2+=1ll*sum[i]*size[i]%P)%=P,(cur3+=sqr(size[i]))%=P;
for (int i=n,cur=n-1,ptr=n-1;i>=1;--i)
{
for (;ptr>=0&&n-SA[p[ptr]]==i;--ptr)
{
int x=getfather(p[ptr]);
cur1=(cur1-sqr(sum[x])+P)%P,cur2=(cur2-1ll*sum[x]*size[x]%P+P)%P,cur3=(cur3-sqr(size[x])+P)%P;
(sum[x]+=n-SA[p[ptr]])%=P,++size[x];
(cur1+=sqr(sum[x]))%=P,(cur2+=1ll*sum[x]*size[x]%P)%=P,(cur3+=sqr(size[x]))%=P;
}
for (;cur>=0&&height[pos[cur]]==i;--cur)
{
int x=pos[cur];
if (x) merge(x-1,x);
}
(ans+=(cur1+1ll*cur3*sqr(i-1)%P-2ll*cur2*(i-1)%P+P)%P)%=P;
}
}
int main()
{
freopen("substring.in","r",stdin),freopen("substring.out","w",stdout);
scanf("%s",s),n=strlen(s);
DA(),getheight(),calc();
printf("%d\n",ans);
fclose(stdin),fclose(stdout);
return 0;
}