CodeChef Arithmetic Progressions (分块FFT)

题意:给出 A 1 ,A 2 ,...,A N 统计满足:
▶ i < j < k
▶ A i + A k = 2A j
的 (i,j,k) 数量。

(N ≤ 30000,A i ≤ 10 5 )

分块,把长度为n的序列分成sqrt(n)块长度为sqrt(n)的序列,然后遍历每一个块分三种情况:

1)三个都在同一个块里面:

暴力枚举后两个,每次维护前面的数的个数,复杂度O(sqrt(n)*n)

2)两个在同一块里面:

暴力枚举块中的两个,维护块前的数的个数和块后的数的个数,复杂度O(sqrt(n)*n)

3)一个在块中,一个在块前,一个在块后

块前块后做fft,枚举块中的数,复杂度O(nlgn*sqrt(n)).

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
#define pi acos (-1)
#define maxn 151111

struct plex {
    double x, y;
    plex (double _x = 0.0, double _y = 0.0) : x (_x), y (_y) {}
    plex operator + (const plex &a) const {
        return plex (x+a.x, y+a.y);
    }
    plex operator - (const plex &a) const {
        return plex (x-a.x, y-a.y);
    }
    plex operator * (const plex &a) const {
        return plex (x*a.x-y*a.y, x*a.y+y*a.x);
    }
};

void change (plex y[] , int len) {
    for (int i = 1 , j = len / 2 ; i < len -1 ; i ++) {
        if (i < j) swap(y[i] , y[j]);
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        if(j < k) j += k;
    }
}

void fft(plex y[],int len,int on)
{
    change(y,len);
    for(int h = 2; h <= len; h <<= 1)
    {
        plex wn(cos(on*2*pi/h),sin(on*2*pi/h));
        for(int j = 0;j < len;j+=h)
        {
            plex w(1,0);
            for(int k = j;k < j+h/2;k++)
            {
                plex u = y[k];
                plex t = w*y[k+h/2];
                y[k] = u+t;
                y[k+h/2] = u-t;
                w = w*wn;
            }
        }
    }
    if(on == -1)
        for(int i = 0;i < len;i++)
            y[i].x /= len;
}
int num[maxn], num2[maxn];
long long sum[maxn];
int a[maxn];
plex x[maxn], y[maxn];
int n;

int main () {
    //freopen ("in.txt", "r", stdin);
    scanf ("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf ("%d", &a[i]);
    }
    long long ans = 0;
    int block = sqrt (n), len = n/block;
    if (len*block != n)
        block++;
    //1 都在当前块
    for (int t = 0; t < block; t++) {
        memset (num, 0, sizeof num);
        for (int i = min ((t+1)*len-1, n-1); i >= t*len; i--) {
            for (int j = t*len; j < i; j++)
                num[a[j]]++;
            for (int j = i-1; j >= t*len; j--) {
                num[a[j]]--;
                int cur = 2*a[j]-a[i];
                if (cur > 0)
                    ans += num[cur];
            }
        }
    }
    //两个在当前块
    for (int t = 0; t < block; t++) {
        memset (num, 0, sizeof num);
        //第三个在前面的块中
        for (int i = 0; i < t*len; i++)
            num[a[i]]++;
        for (int i = t*len; i < n && i < (t+1)*len; i++) {
            for (int j = i+1; j < n && j < (t+1)*len; j++) {
                int cur = a[i]*2-a[j];
                if (cur > 0)
                    ans += num[cur];
            }
        }
        //第三个在后面的块中
        memset (num, 0, sizeof num);
        for (int i = (t+1)*len; i < n; i++)
            num[a[i]]++;
        for (int i = t*len; i < n && i < (t+1)*len; i++) {
            for (int j = i+1; j < n && j < (t+1)*len; j++) {
                int cur = a[j]*2-a[i];
                if (cur > 0)
                    ans += num[cur];
            }
        }
    }
    //只有一个在当前块中 一个在前面 一个在后面
    for (int t = 0; t < block; t++) {
        int cnt1 = 0, cnt2 = 0, Max = 0;
        memset (num, 0, sizeof num);
        for (int i = 0; i < t*len; i++) {
            num[a[i]]++;
            Max = max (Max, a[i]);
        }
        memset (num2, 0, sizeof num2);
        for (int i = (t+1)*len; i < n; i++) {
            num2[a[i]]++;
            Max = max (Max, a[i]);
        }
        int l = 1;
        Max++;
        while (l < 2*Max)
            l <<= 1;
        for (int i = 0; i < l; i++) {
            x[i] = plex (num[i], 0);
        }
        for (int i = 0; i < l; i++) {
            y[i] = plex (num2[i], 0);
        }
        fft (x, l, 1);
        fft (y, l, 1);
        for (int i = 0; i < l; i++)
            x[i] = x[i]*y[i];
        fft (x, l, -1);
        memset (sum, 0, sizeof sum);
        for (int i = 1; i < l; i++) {
            sum[i] = (long long) (x[i].x+0.5);
        }
        for (int i = t*len; i < (t+1)*len && i < n; i++) {
            ans += sum[2*a[i]];
        }
    }
    printf ("%lld\n", ans);
    return 0;
}


你可能感兴趣的:(CodeChef Arithmetic Progressions (分块FFT))