codechef Arithmetic Progressions(分块+FFT)

题目链接:http://www.codechef.com/problems/COUNTARI/

题意:给出一个长度为n的数列A。求有多少三元组(i,j,k)满足i<j<k且A[i]-A[j]=A[j]-A[k]?

思路:将数列分成K块。那么对于每三个数:

(1)在一块中,枚举两个,查找第三个数的个数;

(2)两个数在一块中,另一个数在前面或者后面的块中,也是枚举该块内的两个数,查找第三个数的个数;

(3)三个数在三块中,枚举中间一块中的数,用FFT计算前后两部分。

 

struct node

{

    double x,y;



    node(double _x=0.0,double _y=0.0)

    {

        x=_x;

        y=_y;

    }



    node operator+(node a)

    {

        return node(x+a.x,y+a.y);

    }



    node operator-(node a)

    {

        return node(x-a.x,y-a.y);

    }



    node operator*(node a)

    {

        return node(x*a.x-y*a.y,x*a.y+y*a.x);

    }

};



node A[N];

int L;





int reverse(int x)

{

    int ans=0,i;

    FOR0(i,L) if(x&(1<<i)) ans|=1<<(L-1-i);

    return ans;

}





void bitReverseCopy(node a[],int n)

{

    int i;

    FOR0(i,n) A[i]=a[i];

    FOR0(i,n)

    {

        a[reverse(i)]=A[i];

    }

}





void fft(node a[],int n,int on)

{

    bitReverseCopy(a,n);

    int len,i,j,k;

    node x,y,u,t;

    for(len=2;len<=n;len<<=1)

    {

        x=node(cos(-on*2*PI/len),sin(-on*2*PI/len));

        for(j=0;j<n;j+=len)

        {

            y=node(1,0);

            for(k=j;k<j+len/2;k++)

            {

                u=a[k];

                t=y*a[k+len/2];

                a[k]=u+t;

                a[k+len/2]=u-t;

                y=y*x;

            }

        }

    }

    if(on==-1)

    {

        FOR0(i,n) a[i].x/=n;

    }

}



int a[N],C[N],n;

int size;

i64 ans;

int b[100][2],bNum;



void deal1()

{

    int L,R,i,j,k;

    for(L=1;L<=n;L+=size)

    {

        R=min(n,L+size-1);

        b[bNum][0]=L;

        b[bNum][1]=R;

        bNum++;

        for(i=L;i<=R;i++)

        {

            for(j=i+1;j<=R;j++)

            {

                k=2*a[i]-a[j];

                if(k>=0) ans+=C[k];

            }

            C[a[i]]++;

        }

        for(i=L;i<=R;i++) C[a[i]]--;

    }

}





void deal2()

{

    int L,R,i,j,k,t;

    for(t=0;t<bNum;t++)

    {

        L=b[t][0];

        R=b[t][1];

        for(i=L;i<=R;i++) for(j=i+1;j<=R;j++)

        {

            k=2*a[i]-a[j];

            if(k>=0) ans+=C[k];

        }

        for(i=L;i<=R;i++) C[a[i]]++;

    }

    for(i=1;i<=n;i++) C[a[i]]--;

    for(t=bNum-1;t>=0;t--)

    {

        L=b[t][0];

        R=b[t][1];

        for(i=R;i>=L;i--) for(j=i-1;j>=L;j--)

        {

            k=2*a[i]-a[j];

            if(k>=0) ans+=C[k];

        }

        for(i=L;i<=R;i++) C[a[i]]++;

    }

    for(i=1;i<=n;i++) C[a[i]]--;

}



int pre[N],tail[N];



node P[N],Q[N];

int M;

i64 Ans[N];



void init()

{

    M=1; L=0;

    while(M<=60000) M<<=1,L++;

    int i;

    for(i=0;i<=30000;i++) P[i]=node(pre[i],0),Q[i]=node(tail[i],0);

    while(i<M) P[i]=node(0,0),Q[i]=node(0,0),i++;

    fft(P,M,1);

    fft(Q,M,1);

    for(i=0;i<M;i++) P[i]=P[i]*Q[i];

    fft(P,M,-1);

    for(i=0;i<M;i++) Ans[i]=(i64)(P[i].x+0.5);

}



void deal3()

{

    int i,j,k,t,L,R;



    L=b[0][0]; R=b[0][1];

    for(i=L;i<=R;i++) pre[a[i]]++;

    for(t=1;t<bNum;t++)

    {

        L=b[t][0];

        R=b[t][1];

        for(i=L;i<=R;i++) tail[a[i]]++;

    }

    for(t=1;t<=bNum-2;t++)

    {

        L=b[t][0];

        R=b[t][1];

        for(i=L;i<=R;i++) tail[a[i]]--;

        init();

        for(i=L;i<=R;i++) ans+=Ans[a[i]<<1];

        for(i=L;i<=R;i++) pre[a[i]]++;

    }

}



int main()

{

    RD(n);

    int i;

    FOR1(i,n) RD(a[i]);

    size=n/min(n,35); ans=0;

    deal1();

    deal2();

    deal3();

    printf("%lld\n",ans);

}

  

你可能感兴趣的:(progress)