bzoj 3509: [CodeChef] COUNTARI (FFT+分块)

题目描述

传送门

题目大意:给定一个长度为N的数组A[],求有多少对i, j, k(1<=i

题解

对序列进行分块。然后分情况讨论。
(1)i,j,k在同一个块中,从左向右顺序枚举i,j,对于j后面的数字出现情况用cnt[i]数组动态维护,每次计算答案的时候加上 cnt[2a[j]a[i]]
(2)i,j在同一个块中,k在另一个块中或者j,k在同一个块中,i在另一个块中,用上面类似的方式维护。
(3)i,j,k在三个不同的块中,枚举j所在的块,用两个数组动态维护左右两端数字的出现情况,然后用FFT优化,得到 c[x] ,其中 x=a[i]+a[k] ,然后枚举j所在块中的数字,每次加上 c[a[j]2]
块的大小在2000左右比较合适。

代码

#include
#include
#include
#include
#include
#define N 200003
#define LL long long 
#define pi acos(-1)
using namespace std;
struct data{
    double x,y;
    data(double X=0,double Y=0) {
        x=X,y=Y;
    }
}f[N],g[N];
int n,n1,m,cnt[N],a[N],cntl[N],cntr[N],L,R[N],blocksize,l[N],r[N];
LL c[N];
data operator +(data a,data b){
    return data(a.x+b.x,a.y+b.y);
}
data operator -(data a,data b){
    return data(a.x-b.x,a.y-b.y);
}
data operator *(data a,data b){
    return data(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
void FFT(data a[],int n,int opt)
{
    for (int i=0;iif (i>R[i]) swap(a[i],a[R[i]]);
    for (int i=1;i1) {
        data wn=data(cos(pi/i),opt*sin(pi/i));
        for (int p=i<<1,j=0;j1,0);
            for (int k=0;kint main()
{
    freopen("a.in","r",stdin);
    freopen("my.out","w",stdout);
    scanf("%d",&n);
    int mx=0;
    for (int i=1;i<=n;i++) scanf("%d",&a[i]),mx=max(mx,a[i]);
    mx*=2;
    for (n1=1;n1<=mx;n1<<=1) L++;
    for (int i=0;i<=n1;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
    LL ans=0;
    blocksize=min(n,2000);
    int t=(n-1)/blocksize+1;
    for (int i=1;i<=t;i++) {
        l[i]=(i-1)*blocksize+1;
        r[i]=min(l[i]+blocksize-1,n);
    }
    for (int i=1;i<=n;i++) cnt[a[i]]++;
    for (int k=1;k<=t;k++) {
        for (int i=l[k];i<=r[k];i++) cnt[a[i]]--;
        for (int i=l[k];i<=r[k];i++)
         for (int j=l[k];jint v=a[i]-a[j];
            v=a[i]+v;
            if (v>=0)ans+=(LL)cnt[v];
         }
    } 
    for (int i=1;i<=n;i++) cnt[a[i]]++;
    for (int k=t;k>=1;k--) {
        for (int i=l[k];i<=r[k];i++) cnt[a[i]]--;
        for (int i=l[k];i<=r[k];i++)
         for (int j=l[k];jint v=a[i]-a[j];
            v=a[j]-v;
            if (v>=0) ans+=(LL)cnt[v];
         }
    }
    for (int k=1;k<=t;k++) {
        for (int i=l[k-1];i<=r[k-1];i++) cnt[a[i]]=0;
        for (int i=l[k];i<=r[k];i++) cnt[a[i]]++;
        for (int i=l[k];i<=r[k];i++) {
            cnt[a[i]]--;
            for (int j=l[k];jint v=a[i]-a[j];
                v=a[i]+v;
                if (v>=0) ans+=(LL)cnt[v];
            }
        }
    }
    if (t>=3) {
        for (int i=l[1];i<=r[1];i++) cntl[a[i]]++;
        for (int i=l[3];i<=n;i++) cntr[a[i]]++;
        for (int k=2;k<=t-1;k++) {
            memset(f,0,sizeof(f));
            memset(g,0,sizeof(g));
            for (int i=0;i<=mx;i++) f[i].x=cntl[i];
            for (int i=0;i<=mx;i++) g[i].x=cntr[i];
            FFT(f,n1,1); FFT(g,n1,1);
            for (int i=0;i1);
            for (int i=0;i<=mx;i++) c[i]=(LL)(f[i].x/n1+0.5);
            for (int i=l[k];i<=r[k];i++) ans+=c[a[i]*2];
            for (int i=l[k];i<=r[k];i++) cntl[a[i]]++;
            for (int i=l[k+1];i<=r[k+1];i++) cntr[a[i]]--;
        }
    }
    printf("%lld\n",ans);
}

你可能感兴趣的:(FFT)