题目链接
题意:
给你一个字符串,问你有多少个子串,形式是 A A B B AABB AABB,其中 A A A和 B B B可以相同。串长 < = 30000 <=30000 <=30000,数据组数 < = 10 <=10 <=10。
题解:
一个早就听说有95分暴力的题。然后现在发现正解其实并不是特别好想吧。我又靠题解度日了。
首先我们可以想到,我们并不需要check一个串是不是满足那个形式,我们只需要对于每一个位置 i i i,求出从 i i i向前有多少个形式是 A A AA AA的串,向后有多少个长度是 A A AA AA的串,然后我们把一个相邻的位置的前面一个位置向前的和后面一个位置向后的答案乘起来,就是以这两个位置为中心的串的个数。有点拗口,用式子表示一下,我们设 x [ i ] x[i] x[i]为从 i i i这个位置向前有多少个形式是 A A AA AA的串,设 y [ i ] y[i] y[i]表示从 i i i开始向后有多少个形式是 A A AA AA的串,那么答案就是 ∑ i = 1 n − 1 x [ i ] ∗ y [ i + 1 ] \sum_{i=1}^{n-1}x[i]*y[i+1] ∑i=1n−1x[i]∗y[i+1]。
那么接下来我们的问题就是如何求一每个位置有多少个向前和向后形式是 A A AA AA的串。我们的做法是枚举形如 A A AA AA的串中一个 A A A的长度 l e n len len,然后我们每隔 l e n len len这么长设一个断点,那么一个合法的串 A A AA AA应该是覆盖正好两个断点的。那么我们对于两个相邻的断点,考虑统计答案。我们设两个断点分别是 i i i和 j j j,其中 j = i + l e n j=i+len j=i+len。我们求出从 i i i和 j j j开始向前向后的最长公共后缀和最长公共前缀,这样求出的就是从 i i i向前向后拓展出的最长的 A A A,从 j j j也可以同样向前向后拓展出这样一个 A A A,拼起来就成了我们要的 A A AA AA,如果这个 A A A的长度大于等于 l e n len len,那么就出现了可行的答案。由于要求正好跨过两个断点,所以这个长度还要与 l e n len len取个min。我们会发现可以作为 A A AA AA的起点的位置是左边连续的一段,可以作为终点的位置是右边连续的一段,我们再写一个区间加复杂度就变高了,由于我们只需要在最后查询,所以我们差分一下,在左端点 + 1 +1 +1,在右端点 − 1 -1 −1就可以了。其中求最长公共前缀和最长公共后缀的时候我们是先用后缀数组+RMQ预处理出来的。这样我们最后再对与起点和终点的合法串各做一下前缀和就可以了。最后再根据一开始那个相乘的式子就可以求出答案。
根据调和级数可以知道,枚举长度每次跳那么长的复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)的,我这里可能写的不是很优秀,你要是预处理一个2的次幂和每一个数的log2的答案就可以做到总复杂度 O ( T n l o g n ) O(Tnlogn) O(Tnlogn)了。多组询问别忘记清空数组。
代码:
#include
using namespace std;
int T,sa[30010][2],rk[30010][2],he[30010][18][2],n;
int s[30010],b[30010],c[30010];
long long ans,cnt1[30010],cnt2[30010];
char a[30010];
inline void get_sa(int id)
{
for(int i=1;i<=n;++i)
rk[i][id]=a[i];
memset(s,0,sizeof(s));
for(int i=1;i<=n;++i)
s[rk[i][id]]++;
for(int i=1;i<=255;++i)
s[i]+=s[i-1];
for(int i=n;i>=1;--i)
sa[s[rk[i][id]]--][id]=i;
int len=1,p=0,x=255;
while(p<n)
{
int k=0;
for(int i=n-len+1;i<=n;++i)
b[++k]=i;
for(int i=1;i<=n;++i)
{
if(sa[i][id]>len)
b[++k]=sa[i][id]-len;
}
for(int i=1;i<=n;++i)
c[i]=rk[b[i]][id];
memset(s,0,sizeof(s));
for(int i=1;i<=n;++i)
s[c[i]]++;
for(int i=1;i<=x;++i)
s[i]+=s[i-1];
for(int i=n;i>=1;--i)
sa[s[c[i]]--][id]=b[i];
for(int i=1;i<=n;++i)
c[i]=rk[i][id];
p=1;
rk[sa[1][id]][id]=1;
for(int i=2;i<=n;++i)
{
if(!(c[sa[i][id]]==c[sa[i-1][id]]&&c[sa[i][id]+len]==c[sa[i-1][id]+len]))
++p;
rk[sa[i][id]][id]=p;
}
len*=2;
x=p;
}
x=0;
for(int i=1;i<=n;++i)
{
if(rk[i][id]==1)
continue;
if(x)
--x;
int j=sa[rk[i][id]-1][id];
while(a[i+x]==a[j+x])
++x;
he[rk[i][id]][0][id]=x;
}
for(int j=1;(1<<j)<=n;++j)
{
for(int i=1;i<=n-(1<<j)+1;++i)
he[i][j][id]=min(he[i][j-1][id],he[i+(1<<(j-1))][j-1][id]);
}
}
inline int lcp(int x,int y,int id)
{
if(x>y)
swap(x,y);
int z=log2(y-x);
return min(he[x+1][z][id],he[y-(1<<z)+1][z][id]);
}
int main()
{
scanf("%d",&T);
while(T--)
{
ans=0;
memset(cnt1,0,sizeof(cnt1));
memset(cnt2,0,sizeof(cnt2));
memset(a,0,sizeof(a));
memset(sa,0,sizeof(sa));
memset(he,0,sizeof(he));
memset(rk,0,sizeof(rk));
memset(b,0,sizeof(b));
memset(c,0,sizeof(c));
scanf("%s",a+1);
n=strlen(a+1);
get_sa(0);
reverse(a+1,a+n+1);
get_sa(1);
for(int len=1;len<=n/2;++len)
{
for(int i=len,j=len<<1;j<=n;i+=len,j+=len)
{
int x=min(lcp(rk[i][0],rk[j][0],0),len),y=min(lcp(rk[n-(i-1)+1][1],rk[n-(j-1)+1][1],1),len-1);
int ji=x+y-len+1;
if(x+y>=len)
{
cnt1[i-y]++;
cnt1[i-y+ji]--;
cnt2[j+x-ji]++;
cnt2[j+x]--;
}
}
}
for(int i=1;i<=n;++i)
{
cnt1[i]+=cnt1[i-1];
cnt2[i]+=cnt2[i-1];
}
for(int i=1;i<=n-1;++i)
ans+=cnt2[i]*cnt1[i+1];
printf("%lld\n",ans);
}
return 0;
}