bzoj3160 万径人踪灭 FFT+manacher

       一道挺不错的题目。关键是想到卷积(相信大神看到这就会做了,不对,大神还需要来看我的博客吗)。

       首先我们可以求出所有的回文子序列,然后减去回文子串的数量,就可以得到答案了。回文子串的数量可以用manacher算法O(N)得到,那么就看怎么得到回文子序列了。

       不妨来看以一个点i为中心有多少回文子序列(以夹缝为中心的回文子序列同理)。我们发现关键是统计有多少个k,满足ch[i-k]=ch[i+k],那么以i为中心的点回文子序列的个数为2^k-1(自己不能作为回文子序列)。另外我们发现不同的字母得到的k是相互独立的。因此我们求出有多少ch[i-k]=ch[i+k]=a,以及有多少ch[i-k]=ch[i+k]=b,然后把两个k加起来就行了。

       那么考虑字母'a'带来的影响。可以设一个数组a[i],当ch[i]='a'时a[i]=1。那么对于点i就是统计Σa[i-k]*a[i+k],如果这样还不是很明显,那么我们令点i的答案为2i(i和i+1的夹缝为2i+1),可以看到b[2i]=Σa[i-k]*a[i+k]=Σ(j=0,i)a[j]*a[2i-j]!这就是一个卷积的形式即:b[i]=Σa[j]a[i-j],可以看到对于夹缝这个也是同样成立的!

       然后用FFT加速卷积计算即可。注意可以不需要求出'a'和'b'的两个卷积,而可以求出两个点值表达式后合并到一起,然后就只需要求一次差值即可。

AC代码如下:

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#define pi acos(-1.0)
#define mod 1000000007
#define N 300005
using namespace std;

int n,m,pos[N],f[N],bin[N]; char ch[N],s[N];
struct cpx{ double r,i; }a[N],b[N];
cpx operator +(cpx x,cpx y){ x.r+=y.r; x.i+=y.i; return x; }
cpx operator -(cpx x,cpx y){ x.r-=y.r; x.i-=y.i; return x; }
cpx operator *(cpx x,cpx y){
	cpx z; z.r=x.r*y.r-x.i*y.i; z.i=x.r*y.i+x.i*y.r; return z;
}
void dft(cpx *a,int p){
	int i,j,k,mid; cpx w,wn,u,v;
	for (k=2; k<=m; k<<=1){
		wn.r=cos(pi*2.0/k*p); wn.i=sin(pi*2.0/k*p); mid=k>>1;
		for (i=0; i<m; i+=k){
			w.r=1; w.i=0;
			for (j=i; j<i+mid; j++){
				u=a[j]; v=a[j+mid]*w;
				a[j]=u+v; a[j+mid]=u-v; w=w*wn;
			}
		}
	}
	if (p<0) for (i=0; i<m; i++) a[i].r/=m;
}
int main(){
	scanf("%s",ch+1); n=strlen(ch+1); int i,j,k,cnt=0;
	m=n<<1|1; for (i=1; i<m; i<<=1) cnt++; m=i;
	for (i=0; i<m; i++)
		for (k=i,j=cnt; j; j--,k>>=1) pos[i]=pos[i]<<1|(k&1);
	for (i=1; i<=n; i++)
		if (ch[i]=='a') a[pos[i]].r=1; else b[pos[i]].r=1;
	dft(a,1); dft(b,1);
	for (i=0; i<m; i++) b[i]=a[i]*a[i]+b[i]*b[i];
	for (i=0; i<m; i++) a[pos[i]]=b[i];
	dft(a,-1); int ans=0;
	bin[0]=1; for (i=1; i<=n; i++) bin[i]=(bin[i-1]<<1)%mod;
	for (i=0; i<m; i++)
		ans=(ans+bin[((int)(a[i].r+0.5)+1)>>1]-1)%mod;
	for (i=1; i<=n; i++){
		s[i<<1]=ch[i]; s[i<<1|1]='#';
	}
	int len=n<<1|1,mx=0;
	s[1]='#'; s[0]='$';  s[len+1]='@';
	for (i=2; i<len; i++){
		f[i]=(mx>i)?min(mx-i,f[(k<<1)-i]):1;
		while (s[i-f[i]]==s[i+f[i]]) f[i]++;
		if (i+f[i]>mx){ mx=i+f[i]; k=i; }
		ans=(ans-(f[i]>>1)+mod)%mod;
	}
	printf("%d\n",ans);
	return 0;
}


by lych

2016.3.9 

你可能感兴趣的:(字符串,fft,Manacher)