BZOJ 3771 Triple FFT+容斥原理

题意:链接

方法: FFT+容斥原理

解析:

这东西其实就是指数型母函数?

所以刚开始读入的值我们都把它前面的系数置为1。

然后其实就是个多项式乘法了。

最大范围显然是读入的值中的最大值乘三,对于本题的话是12W?

用FFT优化的话,达到了O(nlogn),显然可过。

但是这里有一个问题,就是如何处理重复的部分。

重复的部分我们考虑用容斥原理来解决。

为了方便描述我们不妨设三个多项式。

第一个是仅取一个而构成的多项式。->x

第二个是仅取相同的两个而构成的多项式。->y

第三个是仅取相同的三个而构成的多项式。->z

对于本题有三种情况。

第一种是取一个,显然直接将x加到答案就好。

第二种是取两个,则需要一小步容斥,即(x*x-y)/2

第三种是取三个,则需要进一步容斥,即(x*x*x-3*x*y+2*z)/6

至于第三种,简单说明一下,将x*x*x算上是代表了所有取三个的情况,减去3*x*y其实是减掉一个x*y/2,即选两个相同的再除以排列数,因为大情况是3!重复,故添了个系数,减两个相同的肯定包含三个相同的,所以要加回来两个。

最后统计答案即可。

代码:

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 131072
#define pi acos(-1)
using namespace std;
int n; 
struct complex
{
    double r,i;
    complex(double x=0.0,double y=0.0){r=x,i=y;}
    complex operator + (const complex a)
    {return complex(a.r+r,a.i+i);}
    complex operator - (const complex a) 
    {return complex(r-a.r,i-a.i);}
    complex operator * (const complex a)
    {return complex(r*a.r-i*a.i,r*a.i+i*a.r);}
}a[N+10],b[N+10],c[N+10],d[N+10];
int rev[N+10];
void FFT(complex *a,int f)
{
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int h=2;h<=n;h<<=1)
    {
        complex wn(cos(2*pi*f/h),sin(2*pi*f/h));
        for(int i=0;i<n;i+=h)
        {
            complex w(1,0);
            for(int j=0;j<(h>>1);j++,w=w*wn)
            {
                complex t=a[i+j+(h>>1)]*w;
                a[i+j+(h>>1)]=a[i+j]-t;
                a[i+j]=a[i+j]+t;
            }
        }
    }
    if(f==-1)for(int i=0;i<n;i++)a[i].r/=n;
}
int main()
{
    int ma=-1;
    scanf("%d",&n);n--;
    for(int i=0;i<=n;i++)
    {
        int x;
        scanf("%d",&x);
        a[x].r=1,b[2*x].r=1,c[3*x].r=1;
        ma=max(ma,3*x);
    }
    int m=ma,L=0;
    for(n=1;n<=m;n<<=1)L++;
    for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    FFT(a,1),FFT(b,1),FFT(c,1);
    for(int i=0;i<=n;i++)
    {
        complex tmp(1.0/6.0,0);
        complex tmp2(3.0,0);
        complex tmp3(2.0,0);
        complex tmp4(1.0/2.0,0);
        d[i]=d[i]+(a[i]*a[i]*a[i]-tmp2*a[i]*b[i]+tmp3*c[i])*tmp;
        d[i]=d[i]+(a[i]*a[i]-b[i])*tmp4;
        d[i]=d[i]+a[i];
    }
    FFT(d,-1); 
    for(int i=0;i<=n;i++)
    {
        int print=(int)(d[i].r+0.1);
        if(print!=0)printf("%d %d\n",i,print);
    } 
}

你可能感兴趣的:(X,fft)