【bzoj3771】Triple FFT

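a[x]表示取一个物品、总价值为x的方案数
b[x]表示同一个物品取两次、总价值为x的方案数(即b[2k]=a[k])
c[x]表示同一个物品取三次、总价值为x的方案数(即c[3k]=a[k])
最终(容斥原理),取一个的方案数是a
取两个的方案数是(a*a-b)/2
取三个的方案数是(a*a*a-3*a*b+2*c)/6
(a*a*a按有序三元组计数:恰有两个相同的情况被3*a*b恰好减掉,三个都相同的情况被多减了两次,所以用2*c补回)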

a*a*a用FFT算就可以了

乘法是序列的卷积


#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<iostream>
#define maxn 200100 
#define pi acos(-1)

using namespace std;

// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
// The operators are const and take their operand by const reference, so they
// also work on const values. Each builds its result in a separate local, so
// aliasing (e.g. w=w*w) is safe.
struct yts
{
	double r,i;
	yts operator+(const yts &x) const {yts ans;ans.r=r+x.r;ans.i=i+x.i;return ans;}
	yts operator-(const yts &x) const {yts ans;ans.r=r-x.r;ans.i=i-x.i;return ans;}
	yts operator*(const yts &x) const {yts ans;ans.r=r*x.r-i*x.i;ans.i=r*x.i+i*x.r;return ans;}
}a[maxn],	// a: counts of single picks by value (transformed in place by main)
 b[maxn],	// b: a*a (pairs, ordered, may reuse an item)
 c[maxn],	// c: a*a*a (triples, ordered, may reuse items)
 d[maxn],	// d: (item taken twice) convolved with a
 temp[maxn];	// temp: scratch buffer used by FFT

// Final answers: ans[k][s] is the number of ways to choose k+1 distinct items
// with total value s (also used for intermediate sums before division in main).
long long ans[3][maxn];
int n;		// number of items read from input
int m;		// not referenced anywhere in this file
int mx;		// largest item value; main later triples it to size the FFT
int digit;	// FFT length: smallest power of two >= 3 * largest value
int seq[maxn];	// seq[1..n]: item values in input order

// Recursive radix-2 Cooley-Tukey FFT, performed in place on x[0..n).
//   x    - data to transform; overwritten with the result
//   n    - transform length; must be a power of two
//   type - +1 for the forward transform, -1 for the inverse transform
//          (the inverse is NOT scaled; the caller divides by n)
// Uses the global scratch array `temp`, which must hold at least n elements.
void FFT(yts x[],int n,int type)
{
	if (n==1) return;
	const int half=n>>1;
	// Reorder: even-indexed samples to the first half, odd-indexed to the second.
	for (int i=0;i<half;i++) temp[i]=x[2*i],temp[half+i]=x[2*i+1];
	memcpy(x,temp,sizeof(yts)*n);
	yts *l=x,*r=x+half;
	FFT(l,half,type);FFT(r,half,type);
	// Butterfly. Each twiddle factor is computed directly from its angle instead
	// of by repeated multiplication (w=w*root). The repeated product accumulates
	// rounding error along the loop, which matters when the transformed values
	// are as large as n^3.
	const double step=2.0*type*acos(-1.0)/n;
	for (int i=0;i<half;i++)
	{
		yts w;
		w.r=cos(step*i);w.i=sin(step*i);
		yts t=w*r[i];
		temp[i]=l[i]+t;
		temp[half+i]=l[i]-t;
	}
	memcpy(x,temp,sizeof(yts)*n);
}

int main()
{
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
	{
		int x;
		scanf("%d",&x);
		seq[i]=x;
		a[x].r++;ans[0][x]++;
		mx=max(mx,x);
	}
	mx*=3;
	for (digit=1;digit<mx;digit<<=1);
	FFT(a,digit,1);
	for (int i=0;i<=digit;i++) b[i]=a[i]*a[i],c[i]=a[i]*a[i]*a[i];
	FFT(b,digit,-1);FFT(c,digit,-1);
	for (int i=1;i<=n;i++) d[2*seq[i]].r++;
	FFT(d,digit,1);
	for (int i=0;i<=digit;i++) d[i]=d[i]*a[i];
	FFT(d,digit,-1);
	for (int i=0;i<=digit;i++)
	{
		ans[1][i]=(long long)(b[i].r/digit+0.5);
		ans[2][i]=(long long)(c[i].r/digit+0.5)-3*(long long)(d[i].r/digit+0.5);
	}
	for (int i=1;i<=n;i++) ans[1][2*seq[i]]--,ans[2][3*seq[i]]+=2;
	for (int i=0;i<=digit;i++) ans[1][i]/=2,ans[2][i]/=6;
	for (int i=0;i<=digit;i++)
	  if (ans[0][i]+ans[1][i]+ans[2][i]>0) printf("%d %lld\n",i,ans[0][i]+ans[1][i]+ans[2][i]);
	return 0;
}


你可能感兴趣的:(【bzoj3771】Triple FFT)