FFT:BZOJ4503 两个串

题目描述:戳这里

题解:

如果没有"?",那么我们可以用kmp。
我们可以把这道题目抽象成一个和式:
假设两串S,T分别是0~n,0~m,翻转T串(变成m~0)。
假设T串中"?"的位置都设为0。
假设S串从第x个位置开始匹配可以匹配完T串,那么等价于要满足:
∑ 0 m ( S x + i − T m − i ) 2 T m − i = 0 \sum_0^m(S_{x+i}-T_{m-i})^2T_{m-i}=0 0m(Sx+iTmi)2Tmi=0
化简一下:
∑ 0 m S x + i 2 T m − i − 2 S x + i T m − i + T m − i 3 \sum_0^mS_{x+i}^2T_{m-i}-2S_{x+i}T_{m-i}+T_{m-i}^3 0mSx+i2Tmi2Sx+iTmi+Tmi3
对于最后一项,可以前缀和,前面两项,都是卷积的形式,可以通过FFT来快速解决。
那么复杂度就是 O ( n l o g n ) O(nlogn) O(nlogn)

代码如下:

#include
#define ll long long
using namespace std;
const int maxn=(1<<18)+5;
const double Pi=acos(-1.0);
int n,m,limn,R[maxn];
char S[maxn],T[maxn];
struct comx{
	double x,y;
	comx(double xx=0,double yy=0){x=xx,y=yy;}
	comx operator +(const comx b){return comx(x+b.x,y+b.y);}
	comx operator -(const comx b){return comx(x-b.x,y-b.y);}
	comx operator *(const comx b){return comx(x*b.x-y*b.y,x*b.y+y*b.x);} 
}a[maxn],b[maxn],c[maxn],d[maxn],w[maxn];
double s;
void pre(){
	int L=0; limn=1; while (limn<=n+m) limn<<=1,L++;
	for (int i=0;i<limn;i++){
		R[i]=((R[i>>1]>>1)|((i&1)<<(L-1)));
		w[i]=comx(cos(2*Pi/limn*i),sin(2*Pi/limn*i));
	}
}
void FFT(comx *a,int lim){
	for (int i=0;i<lim;i++) if (R[i]>i) swap(a[R[i]],a[i]);
	for (int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
	for (int i=0;i<lim;i+=(d<<1))
	for (int j=0;j<d;j++){
		comx p=w[t*j]*a[i+j+d];
		a[i+j+d]=a[i+j]-p,a[i+j]=a[i+j]+p;
	}
}
void doit(comx *p,comx *q){
	FFT(p,limn); FFT(q,limn);
	for (int i=0;i<limn;i++) p[i]=p[i]*q[i],w[i].y=-w[i].y;
	FFT(p,limn);
	for (int i=0;i<limn;i++) w[i].y=-w[i].y,p[i].x/=limn;
}
ll cal(double x){return (ll)(x+0.5);}
int main(){
	scanf("%s",S); scanf("%s",T);
	n=strlen(S)-1; m=strlen(T)-1;
	for (int i=0;i<=n;i++) c[i].x=S[i]-'a'+1;
	for (int i=0;i<=m;i++)
		if (T[i]!='?') b[m-i].x=T[i]-'a'+1; else b[m-i].x=0;
	for (int i=0;i<=n;i++) a[i].x=c[i].x*c[i].x;
	for (int i=0;i<=m;i++) d[i].x=b[i].x*b[i].x,s+=b[i].x*b[i].x*b[i].x;
	pre(); doit(a,b); doit(c,d);
	int ans=0;
	for (int i=m;i<=n;i++)
		if (cal(a[i].x)-2*cal(c[i].x)+cal(s)==0) ans++;
	printf("%d\n",ans);
	for (int i=m;i<=n;i++)
		if (cal(a[i].x)-2*cal(c[i].x)+cal(s)==0) printf("%d\n",i-m);
	return 0;
}

你可能感兴趣的:(题解,BZOJ题解)