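最后直接计算就可以了：对每个下标和 i（即对称中心），f[i] 是满足 j+k=i 且 s[j]==s[k] 的有序位置对个数，于是无序位置对（含 j=k）共有 cnt=(f[i]+1)/2 个，它贡献 2^cnt-1 个非空的对称选取方案；把所有 i 的贡献求和，再减去用 manacher 求出的回文子串（连续）个数，结果对 1000000007 取模。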
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<string>
#include<vector>
using namespace std;

// Overview
// --------
// Reads one whitespace-delimited token s from stdin and prints a single
// number modulo `mod`:
//
//   sum over every index sum i of (2^cnt(i) - 1)  -  (palindromic substrings of s)
//
// where cnt(i) is the number of unordered position pairs {j, k} (j == k
// allowed) with j + k == i and s[j] == s[k].  The pair counts come from two
// FFT self-convolutions (one for the character 'a', one for 'b'), so only
// those two characters contribute to them; the palindromic-substring count
// comes from Manacher's algorithm.
// NOTE(review): the input is presumably restricted to 'a'/'b' - confirm
// against the problem statement.

// Modulus applied to every accumulated count.
const long long mod = 1000000007;
const double pi = acos(-1.0);

// Minimal complex number used by the FFT.
struct yts
{
    double r, i; // real and imaginary parts

    yts() : r(0), i(0) {}
    yts(double re, double im) : r(re), i(im) {}

    yts operator+(const yts &x) const { return yts(r + x.r, i + x.i); }
    yts operator-(const yts &x) const { return yts(r - x.r, i - x.i); }
    yts operator*(const yts &x) const { return yts(r * x.r - i * x.i, r * x.i + i * x.r); }
};

// All buffers are sized from the actual input in main()/manacher(), so no
// input length can overrun them.
vector<yts> a, b, temp; // FFT input, accumulated spectrum, FFT scratch space
int n, digit;           // length of s; FFT size (power of two, >= 2 * n)
long long ans = 0;
vector<long long> f;    // f[i]: ordered pairs (j, k), j + k == i, s[j] == s[k]
string s, s1;           // input token; separator-interleaved copy for Manacher
vector<int> p;          // Manacher radii over s1
vector<long long> Pow;  // Pow[k] = 2^k mod `mod`

// In-place recursive radix-2 FFT.
//   x    - array of n complex values, overwritten with the transform;
//          must not point into `temp`
//   n    - transform length; must be a power of two (n <= 1 is a no-op)
//   type - +1 or -1, the sign of the exponent.  Neither direction is
//          normalised: the caller divides by n after a forward/backward pair.
void FFT(yts x[], int n, int type)
{
    if (n <= 1) return;
    if ((int)temp.size() < n) temp.resize(n);
    const int half = n >> 1;

    // Even-indexed samples go to the lower half, odd-indexed to the upper half.
    for (int i = 0; i < half; i++)
    {
        temp[i] = x[2 * i];
        temp[half + i] = x[2 * i + 1];
    }
    copy(temp.begin(), temp.begin() + n, x);

    yts *l = x, *r = x + half;
    FFT(l, half, type);
    FFT(r, half, type);

    // Butterfly step; the twiddle factor w advances by one root per iteration.
    const double angle = 2 * pi * type / n;
    const yts root(cos(angle), sin(angle));
    yts w(1, 0);
    for (int i = 0; i < half; i++, w = w * root)
    {
        const yts t = w * r[i];
        temp[i] = l[i] + t;
        temp[half + i] = l[i] - t;
    }
    copy(temp.begin(), temp.begin() + n, x);
}

// Counts the palindromic substrings of the global string s (length n),
// modulo `mod`, with Manacher's algorithm.  Reads s and n; writes only the
// work buffers s1 and p.
long long manacher()
{
    // Layout: s1[odd] = '#', s1[2 * i + 2] = s[i]; index 0 is unused padding.
    const int last = 2 * n + 1;
    s1.assign(last + 1, '#');
    for (int i = 0; i < n; i++) s1[2 * i + 2] = s[i];
    p.assign(last + 1, 0);

    int id = 0, mx = 0; // centre and right end of the furthest-reaching palindrome
    long long total = 0;
    for (int i = 1; i < last; i++)
    {
        p[i] = (mx > i) ? min(mx - i, p[2 * id - i]) : 1;
        // Explicit bounds instead of sentinel characters, so the scan stays in
        // range whatever characters the input contains.
        while (i - p[i] >= 1 && i + p[i] <= last && s1[i + p[i]] == s1[i - p[i]]) p[i]++;
        if (i + p[i] > mx)
        {
            id = i;
            mx = i + p[i];
        }
        // A radius p[i] in s1 corresponds to p[i] / 2 palindromes of s centred here.
        total = (total + p[i] / 2) % mod;
    }
    return total;
}

// Adds, into b, the squared spectrum of the 0/1 indicator sequence of
// character c in s.  After the inverse transform this yields, for every index
// sum, the number of ordered position pairs that both hold c.
static void addSquaredSpectrum(char c)
{
    fill(a.begin(), a.end(), yts());
    for (int i = 0; i < n; i++)
        if (s[i] == c) a[i].r = 1;
    FFT(&a[0], digit, 1);
    for (int i = 0; i < digit; i++) b[i] = b[i] + a[i] * a[i];
}

int main()
{
    // A missing token is treated as the empty string; the answer is then 0.
    if (!(cin >> s)) s.clear();
    n = (int)s.size();

    // Smallest power of two that holds the full self-convolution (length 2n - 1).
    for (digit = 1; digit < (n << 1); digit <<= 1) {}
    a.assign(digit, yts());
    b.assign(digit, yts());
    temp.assign(digit, yts());

    addSquaredSpectrum('a');
    addSquaredSpectrum('b');
    FFT(&b[0], digit, -1);

    // Normalise first, then round: rounding the un-normalised value and
    // truncating afterwards loses a whole unit as soon as the floating-point
    // error on that (much larger) value drops below -0.5.
    f.assign(digit, 0);
    for (int i = 0; i < digit; i++) f[i] = (long long)(b[i].r / digit + 0.5);

    Pow.assign(n + 1, 1);
    for (int i = 1; i <= n; i++) Pow[i] = Pow[i - 1] * 2 % mod;

    // (f[i] + 1) >> 1 turns ordered pairs into unordered ones, counting the
    // j == k pair (present only for even i) once.
    for (int i = 0; i < digit; i++) ans = (ans + Pow[(f[i] + 1) >> 1] - 1) % mod;

    ans = (ans - manacher() + mod) % mod;
    printf("%lld\n", ans);
    return 0;
}