(N ≤ 30000,A i ≤ 10 5 )
分块,把长度为n的序列分成sqrt(n)块长度为sqrt(n)的序列,然后遍历每一个块分三种情况:
1)三个都在同一个块里面:
暴力枚举后两个,每次维护前面的数的个数,复杂度O(sqrt(n)*n)
2)两个在同一块里面:
暴力枚举块中的两个,维护块前的数的个数和块后的数的个数,复杂度O(sqrt(n)*n)
3)一个在块中,一个在块前,一个在块后
块前块后做fft,枚举块中的数,复杂度O(nlgn*sqrt(n)).
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> using namespace std; #define pi acos (-1) #define maxn 151111 struct plex { double x, y; plex (double _x = 0.0, double _y = 0.0) : x (_x), y (_y) {} plex operator + (const plex &a) const { return plex (x+a.x, y+a.y); } plex operator - (const plex &a) const { return plex (x-a.x, y-a.y); } plex operator * (const plex &a) const { return plex (x*a.x-y*a.y, x*a.y+y*a.x); } }; void change (plex y[] , int len) { for (int i = 1 , j = len / 2 ; i < len -1 ; i ++) { if (i < j) swap(y[i] , y[j]); int k = len / 2; while (j >= k) { j -= k; k /= 2; } if(j < k) j += k; } } void fft(plex y[],int len,int on) { change(y,len); for(int h = 2; h <= len; h <<= 1) { plex wn(cos(on*2*pi/h),sin(on*2*pi/h)); for(int j = 0;j < len;j+=h) { plex w(1,0); for(int k = j;k < j+h/2;k++) { plex u = y[k]; plex t = w*y[k+h/2]; y[k] = u+t; y[k+h/2] = u-t; w = w*wn; } } } if(on == -1) for(int i = 0;i < len;i++) y[i].x /= len; } int num[maxn], num2[maxn]; long long sum[maxn]; int a[maxn]; plex x[maxn], y[maxn]; int n; int main () { //freopen ("in.txt", "r", stdin); scanf ("%d", &n); for (int i = 0; i < n; i++) { scanf ("%d", &a[i]); } long long ans = 0; int block = sqrt (n), len = n/block; if (len*block != n) block++; //1 都在当前块 for (int t = 0; t < block; t++) { memset (num, 0, sizeof num); for (int i = min ((t+1)*len-1, n-1); i >= t*len; i--) { for (int j = t*len; j < i; j++) num[a[j]]++; for (int j = i-1; j >= t*len; j--) { num[a[j]]--; int cur = 2*a[j]-a[i]; if (cur > 0) ans += num[cur]; } } } //两个在当前块 for (int t = 0; t < block; t++) { memset (num, 0, sizeof num); //第三个在前面的块中 for (int i = 0; i < t*len; i++) num[a[i]]++; for (int i = t*len; i < n && i < (t+1)*len; i++) { for (int j = i+1; j < n && j < (t+1)*len; j++) { int cur = a[i]*2-a[j]; if (cur > 0) ans += num[cur]; } } //第三个在后面的块中 memset (num, 0, sizeof num); for (int i = (t+1)*len; i < n; i++) num[a[i]]++; for (int i = t*len; i < n && i < (t+1)*len; i++) { for (int j = i+1; j < n && j < (t+1)*len; j++) { int cur = a[j]*2-a[i]; if (cur > 0) ans += num[cur]; } } } //只有一个在当前块中 一个在前面 一个在后面 for (int t = 0; t < block; t++) { int cnt1 = 0, cnt2 = 0, Max = 0; memset (num, 0, sizeof num); for (int i = 0; i < t*len; i++) { num[a[i]]++; Max = max (Max, a[i]); } memset (num2, 0, sizeof num2); for (int i = (t+1)*len; i < n; i++) { num2[a[i]]++; Max = max (Max, a[i]); } int l = 1; Max++; while (l < 2*Max) l <<= 1; for (int i = 0; i < l; i++) { x[i] = plex (num[i], 0); } for (int i = 0; i < l; i++) { y[i] = plex (num2[i], 0); } fft (x, l, 1); fft (y, l, 1); for (int i = 0; i < l; i++) x[i] = x[i]*y[i]; fft (x, l, -1); memset (sum, 0, sizeof sum); for (int i = 1; i < l; i++) { sum[i] = (long long) (x[i].x+0.5); } for (int i = t*len; i < (t+1)*len && i < n; i++) { ans += sum[2*a[i]]; } } printf ("%lld\n", ans); return 0; }