【 bzoj 3509 】[CodeChef] COUNTARI - 分块FFT

  看起来很像数据结构乱搞?
  然而仔细看看数据范围: Ai<=30000
  这意味着我们可以用生成函数来乱搞。
  将式子变成 Ai+Ak=2Aj ,于是就可以很愉快地枚举j,记录j两边的生成函数,然后求卷积
  就行了
  ……?
  毛。。。这样卷积是 O(nVlogV) 的   //以下设 V=max{A}
  怎样降低FFT次数呢?
  考虑分块。假设分成了 B 块,块的大小是 L
  还是枚举j,考虑i在块内的情况,可以很显然地预处理对应数出现次数来搞。k在块内的时候同理。注意这里有一些重复的要减去。然后i,k都在块外就用FFT来搞。
  这样时间复杂度是 O(BVlogV+(B+L)V)
  由于上面这个式子的存在,块不能设太小,不然B就会太大,FFT次数会增多,大概 2000 左右就差不多了。
  用了某个方法稍微优化了一下FFT的常数(不过应该是损耗了一点点精度的吧)。
  

#include 
using namespace std;
#define rep(i,a,b) for(int i = a , _ = b ; i <= _ ; i ++)
#define per(i,a,b) for(int i = a , _ = b ; i >= _ ; i --)
#define For(i,a,b) for(int i = a , _ = b ; i <  _ ; i ++)

inline int rd() {
    char c = getchar();
    while (!isdigit(c)) c = getchar() ; int x = c - '0';
    while (isdigit(c = getchar())) x = x * 10 + c - '0';
    return x;
}

inline void upmax(int&a , int b) { if (a < b) a = b; }
inline void upmin(int&a , int b) { if (a > b) a = b; }

typedef long long ll;

const int maxn = 110007;
const int maxs = 61007;
const int maxb = 307;
const int len = 1823;

typedef int arr[maxn];
typedef int blk[maxb];
typedef int num[maxs];

arr a , belong , L , R;
blk st , ed;
num cnt[maxb] , cnt_nxt , cnt_pre;

ll ans;

int n , tot , lim;

void input() {
    n = rd();
    rep (i , 1 , n) a[i] = rd() , upmax(lim , a[i]);
}

inline void init_block() {
    rep (i , 1 , n) belong[i] = (i - 1) / len + 1;
    tot = belong[n];
    rep (i , 1 , tot) st[i] = (i - 1) * len + 1 , ed[i] = i * len;
    upmin(ed[tot] , n);
    rep (i , 1 , tot) {
        rep (j , st[i] , ed[i])
            cnt[i][a[j]] ++;
    }
}

const int maxN = 524299;

struct comp {
    long double real , imag;
    comp(long double real = 0 , long double imag = 0):real(real) , imag(imag) { }
    inline friend comp operator+(comp&a , comp&b)
        { return comp(a.real + b.real , a.imag + b.imag); }
    inline friend comp operator-(comp&a , comp&b)
        { return comp(a.real - b.real , a.imag - b.imag); }
    inline friend comp operator*(comp&a , comp&b)
        { return comp(a.real * b.real - a.imag * b.imag , a.imag * b.real + a.real * b.imag); }
    inline friend void swap(comp&a , comp&b)
        { comp c = a ; a = b ; b = c; }
}A[maxN];

int rev[maxN] , N , Nlen;

ll res[maxN];

inline void FFT_clear() {
    memset(A , 0 , sizeof(comp) * N);
}

inline void FFT_init() {
    for (N = 1 , Nlen = 0;N <= lim + lim;N <<= 1 , Nlen ++) ;
    For (i , 1 , N) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << Nlen - 1);
}

const long double ppi = std::acos(-1.0) * 2;

inline void FFT(comp*a , int n , int v) {
    For (i , 0 , n) if (i < rev[i]) swap(a[i] , a[rev[i]]);
    for (int s = 2;s <= n;s <<= 1) {
        comp g(cos(ppi / s) , v * sin(ppi / s));
        for (int k = 0;k < n;k += s) {
            comp w(1 , 0);
            For (j , 0 , s / 2) {
                comp t = a[k + j + s / 2] * w , u = a[k + j];
                a[k + j + s / 2] = u - t , a[k + j] = u + t;
                w = w * g;
            }
        }
    }
    if (v == -1) For (i , 0 , n)
        a[i].real /= n * 4, a[i].imag /= n;
}

inline void GetConv() {
    FFT_clear();
    rep (i , 0 , lim) A[i] = comp(L[i] + R[i] , L[i] - R[i]);
    FFT(A , N , 1);
    For (i , 0 , N) A[i] = A[i] * A[i];
    FFT(A , N , -1);
    for (int i = 0;i <= lim + lim;i += 2) res[i] = (ll) (A[i].real + 0.5);
}

inline void calc(int p) {
    int t = a[p] + a[p];
    int cur = belong[p];
    rep (i , st[cur] , p - 1) if (t >= a[i])
        ans += R[t - a[i]];
    rep (i , p + 1 , ed[cur]) if (t >= a[i]) {
        ans += cnt_pre[t - a[i]];
        ans += L[t - a[i]];
        if (cur == 1) continue;
    }
    if (p <= ed[1] || p >= st[tot]) return;
    ans += res[t];
}

void solve() {
    init_block();
    FFT_init();

    rep (i , 1 , n) R[a[i]] ++;

    rep (i , 1 , tot) {
        memset(cnt_pre , 0 , sizeof(int) * (lim + 1));
        rep (j , st[i] , ed[i]) cnt_nxt[a[j]] ++;
        rep (j , 0 , lim) R[j] -= cnt[i][j];

        if (i != 1 && i != tot)
            GetConv();

        rep (j , st[i] , ed[i]) cnt_nxt[a[j]] -- , calc(j) , cnt_pre[a[j]] ++;

        rep (j , 0 , lim) L[j] += cnt[i][j];
    }

    printf("%lld\n" , ans);
}

int main() {
    #ifndef ONLINE_JUDGE
        freopen("data.txt" , "r" , stdin);
        freopen("data.out" , "w" , stdout);
    #endif
    input();
    solve();
    return 0;
}

你可能感兴趣的:(分块,FFT)