首先把两个字符串拼在一起,中间夹一个不可能出现的字符。
然后就是一个简单容斥,我们假设给的字符串为 \(S_1\) 和 \(S_2\),新拼成的字符串为 \(S\),那么答案就是求 \(same(S)-same(S_1)-same(S_2)\),其中 \(same(s)\) 表示 \(s\) 这个字符串中位置不同大小相同的子串的个数。因为容易看出 \(same(S)\) 中统计的要么就是两个字符串都在 \(S_1\),要么就是两个字符串都在 \(S_2\),要么分别在 \(S_1\) 和 \(S_2\) 中(而后者就是本题的答案)。
现在就是考虑给你一个字符串,如何求出它的 \(same(s)\):
\[same(s)=\sum\limits_{i=1}^{n}\sum\limits_{j=i+1}^{n}lcp(Suf_i,Suf_j) \]
原理还是那句话:后缀的前缀可以不重不漏的表示每一个子串。继续:
\[\begin{aligned} same(s)&=\sum\limits_{i=1}^{n}\sum\limits_{j=i+1}^{n}lcp(Suf_i,Suf_j) \\ &=\sum\limits_{i=1}^{n}\sum\limits_{j=i+1}^{n}lcp(Suf_{sa[i]},Suf_{sa[j]}) \\ &=\sum\limits_{i=2}^{n}\sum\limits_{j=i}^{n}\min\limits_{k=i}^{j}height_k \end{aligned} \]
第二行相当于换了个顺序,没什么好解释的,第三行是根据 \(height\) 数组的性质决定的:
\[lcp(i,j)\leq lcp(i,k) \]
其中 \(x_i
这也是我们求 \(lcp\) 是可以用 \(ST\) 表的原理。
回到最后那个式子:可以用单调栈维护左边第一个比当前 \(height_j\) 小的 \(height_k\),然后直接转移就好(具体看代码非常好理解)。
代码:
#include
#include
#include
#include
using namespace std;
typedef long long LL;
const int N=1000000;
char s1[N],s2[N];
int stk[N];
struct Suffix_Array
{
int n,m,h[N],height[N],c[N],x[N],y[N],sa[N];
LL Ans[N];
char s[N];
void clear()
{
memset(h,0,sizeof(h)),memset(c,0,sizeof(c)),memset(x,0,sizeof(x));
memset(y,0,sizeof(y)),memset(sa,0,sizeof(sa)),memset(stk,0,sizeof(stk));
}
void Rsort()
{
for (int i=1;i<=m;i++) c[i]=0;
for (int i=1;i<=n;i++) c[x[y[i]]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i];
}
void Get_SA()
{
clear();
m=122;
for (int i=1;i<=n;i++)
x[i]=s[i],y[i]=i;
Rsort();
for (int k=1;k<=n;k<<=1)
{
int num=0;
for (int i=n-k+1;i<=n;i++)
y[++num]=i;
for (int i=1;i<=n;i++)
if(sa[i]>k)
y[++num]=sa[i]-k;
Rsort(),swap(x,y);
x[sa[1]]=num=1;
for (int i=2;i<=n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
if(num==n) break;
m=num;
}
}
void Get_Height()
{
for (int i=1;i<=n;i++)
{
int tmp=max(0,h[i-1]-1),j,k;
for (j=i+tmp,k=sa[x[i]-1]+tmp;j<=n&&k<=n&&s[j]==s[k];j++,k++);
h[i]=j-i;
}
for (int i=1;i<=n;i++)
height[i]=h[sa[i]];
}
void init()
{
scanf("%s",s+1);
n=strlen(s+1);
}
void work()
{
Get_SA();
Get_Height();
}
LL Get_Ans()
{
memset(Ans,0,sizeof(Ans));
LL ans=0;
int l=1,r=0;
for (int i=2;i<=n;i++)
{
while(l<=r&&height[i]<=height[stk[r]])
r--;
if(r==0) Ans[i]=height[i]*(i-1);
else Ans[i]=height[i]*(i-stk[r])+Ans[stk[r]];
ans+=Ans[i];
stk[++r]=i;
}
return ans;
}
}A,B,C;
void init()
{
A.init();
B.init();
C.n=A.n+B.n+1;
for (int i=1;i<=A.n;i++)
C.s[i]=A.s[i];
C.s[A.n+1]='?';
for (int i=1,j=A.n+2;i<=B.n;i++,j++)
C.s[j]=B.s[i];
}
void work()
{
A.work(),B.work(),C.work();
// printf("%lld %lld %lld\n",A.Get_Ans(),B.Get_Ans(),C.Get_Ans());
printf("%lld\n",C.Get_Ans()-A.Get_Ans()-B.Get_Ans());
}
int main()
{
init();
work();
return 0;
}