一道挺不错的题目。关键是想到卷积(相信大神看到这就会做了,不对,大神还需要来看我的博客吗)。
首先我们可以求出所有的回文子序列,然后减去回文子串的数量,就可以得到答案了。回文子串的数量可以用manacher算法O(N)得到,那么就看怎么得到回文子序列了。
不妨来看以一个点i为中心有多少回文子序列(以夹缝为中心的回文子序列同理)。我们发现关键是统计有多少个k,满足ch[i-k]=ch[i+k],那么以i为中心的点回文子序列的个数为2^k-1(自己不能作为回文子序列)。另外我们发现不同的字母得到的k是相互独立的。因此我们求出有多少ch[i-k]=ch[i+k]=a,以及有多少ch[i-k]=ch[i+k]=b,然后把两个k加起来就行了。
那么考虑字母'a'带来的影响。可以设一个数组a[i],当ch[i]='a'时a[i]=1。那么对于点i就是统计Σa[i-k]*a[i+k],如果这样还不是很明显,那么我们令点i的答案为2i(i和i+1的夹缝为2i+1),可以看到b[2i]=Σa[i-k]*a[i+k]=Σ(j=0,i)a[j]*a[2i-j]!这就是一个卷积的形式即:b[i]=Σa[j]a[i-j],可以看到对于夹缝这个也是同样成立的!
然后用FFT加速卷积计算即可。注意可以不需要求出'a'和'b'的两个卷积,而可以求出两个点值表达式后合并到一起,然后就只需要求一次差值即可。
AC代码如下:
#include<iostream> #include<cstdio> #include<cmath> #include<cstring> #define pi acos(-1.0) #define mod 1000000007 #define N 300005 using namespace std; int n,m,pos[N],f[N],bin[N]; char ch[N],s[N]; struct cpx{ double r,i; }a[N],b[N]; cpx operator +(cpx x,cpx y){ x.r+=y.r; x.i+=y.i; return x; } cpx operator -(cpx x,cpx y){ x.r-=y.r; x.i-=y.i; return x; } cpx operator *(cpx x,cpx y){ cpx z; z.r=x.r*y.r-x.i*y.i; z.i=x.r*y.i+x.i*y.r; return z; } void dft(cpx *a,int p){ int i,j,k,mid; cpx w,wn,u,v; for (k=2; k<=m; k<<=1){ wn.r=cos(pi*2.0/k*p); wn.i=sin(pi*2.0/k*p); mid=k>>1; for (i=0; i<m; i+=k){ w.r=1; w.i=0; for (j=i; j<i+mid; j++){ u=a[j]; v=a[j+mid]*w; a[j]=u+v; a[j+mid]=u-v; w=w*wn; } } } if (p<0) for (i=0; i<m; i++) a[i].r/=m; } int main(){ scanf("%s",ch+1); n=strlen(ch+1); int i,j,k,cnt=0; m=n<<1|1; for (i=1; i<m; i<<=1) cnt++; m=i; for (i=0; i<m; i++) for (k=i,j=cnt; j; j--,k>>=1) pos[i]=pos[i]<<1|(k&1); for (i=1; i<=n; i++) if (ch[i]=='a') a[pos[i]].r=1; else b[pos[i]].r=1; dft(a,1); dft(b,1); for (i=0; i<m; i++) b[i]=a[i]*a[i]+b[i]*b[i]; for (i=0; i<m; i++) a[pos[i]]=b[i]; dft(a,-1); int ans=0; bin[0]=1; for (i=1; i<=n; i++) bin[i]=(bin[i-1]<<1)%mod; for (i=0; i<m; i++) ans=(ans+bin[((int)(a[i].r+0.5)+1)>>1]-1)%mod; for (i=1; i<=n; i++){ s[i<<1]=ch[i]; s[i<<1|1]='#'; } int len=n<<1|1,mx=0; s[1]='#'; s[0]='$'; s[len+1]='@'; for (i=2; i<len; i++){ f[i]=(mx>i)?min(mx-i,f[(k<<1)-i]):1; while (s[i-f[i]]==s[i+f[i]]) f[i]++; if (i+f[i]>mx){ mx=i+f[i]; k=i; } ans=(ans-(f[i]>>1)+mod)%mod; } printf("%d\n",ans); return 0; }
by lych
2016.3.9