快速傅里叶变换,可以将多项式相乘的时间复杂度从最简单的O(n^2)优化到O(nlgn),详细过程参考算法导论.
FFT的流程大致是:
1):构造多项式,复杂度O(n)
2):求两个多项式的DFT,复杂度O(nlgn)
3):构造多项式乘积的点值表达式,复杂度O(n)
4):求点值表达式的IDFT,复杂度O(nlgn).
下面是两道最简单的习题:
HDU 1402:点击打开链接
求两个大数乘积.
因为一个大数可以看成是一个多项式,每一位上的值都表示对应次数下的系数,因此可以用FFT加速.
本体的一个坑点就是
len = l1+l2-1;这句代码,可能是精度问题在len更加高位的地方出现了非0值.
#include <bits/stdc++.h> using namespace std; #define pi acos (-1) #define maxn 200010 struct plex { double x, y; plex (double _x = 0.0, double _y = 0.0) : x (_x), y (_y) {} plex operator + (const plex &a) const { return plex (x+a.x, y+a.y); } plex operator - (const plex &a) const { return plex (x-a.x, y-a.y); } plex operator * (const plex &a) const { return plex (x*a.x-y*a.y, x*a.y+y*a.x); } }; void change (plex y[], int len) { if (len == 1) return ; plex a1[len], a2[len]; for (int i = 0; i < len; i += 2) { a1[i/2] = y[i]; a2[i/2] = y[i+1]; } change (a1, len>>1); change (a2, len>>1); for (int i = 0; i < len/2; i++) { y[i] = a1[i]; y[i+len/2] = a2[i]; } return ; } void fft(plex y[],int len,int on) { change(y,len); for(int h = 2; h <= len; h <<= 1) { plex wn(cos(-on*2*pi/h),sin(-on*2*pi/h)); for(int j = 0;j < len;j+=h) { plex w(1,0); for(int k = j;k < j+h/2;k++) { plex u = y[k]; plex t = w*y[k+h/2]; y[k] = u+t; y[k+h/2] = u-t; w = w*wn; } } } if(on == -1) for(int i = 0;i < len;i++) y[i].x /= len; } char a[maxn], b[maxn]; plex x1[maxn], x2[maxn]; int ans[maxn]; int main () { while (scanf ("%s%s", a, b) == 2) { int len = 2, l1 = strlen (a), l2 = strlen (b); while (len < l1*2 || len < l2*2) len <<= 1; for (int i = 0; i < l1; i++) { x1[i] = plex (a[l1-1-i]-'0', 0); } for (int i = l1; i < len; i++) x1[i] = plex (0, 0); for (int i = 0; i < l2; i++) { x2[i] = plex (b[l2-1-i]-'0', 0); } for (int i = l2; i < len; i++) x2[i] = plex (0, 0); fft (x1, len, 1); fft (x2, len, 1); for (int i = 0; i < len; i++) x1[i] = x1[i]*x2[i]; fft (x1, len, -1); for (int i = 0; i < len; i++) { ans[i] = (int)(x1[i].x+0.5); } for (int i = 0; i < len; i++) { if (ans[i] >= 10) { ans[i+1] += ans[i]/10; ans[i] %= 10; } } len = l1+l2-1; while (ans[len] <= 0 && len > 0) len--; for (int i = len; i >= 0; i--) { printf ("%d", ans[i]); } printf ("\n"); } return 0; }
题意是给出n个长度,任取3个求能组成三角形的概率.
首先记录下每个长度的数量,然后用FFT加速求出任取两个长度下的情况,这里面有重复:
首先减去两次都取同一根的情况,减完之后的结果都/2.
最后只需要所有的情况减去不能组成三角形的情况,将最初的长度序列排序后从小到大枚举下标,假设这条边是最长边,那么如果所有两条边长度小于等于这条边的情况就应该减去,这里用前缀和统计下就好了.
#include <bits/stdc++.h> using namespace std; #define pi acos (-1) #define maxn 611111 struct plex { double x, y; plex (double _x = 0.0, double _y = 0.0) : x (_x), y (_y) {} plex operator + (const plex &a) const { return plex (x+a.x, y+a.y); } plex operator - (const plex &a) const { return plex (x-a.x, y-a.y); } plex operator * (const plex &a) const { return plex (x*a.x-y*a.y, x*a.y+y*a.x); } }; void change (plex y[], int len) { if (len == 1) return ; plex a1[len/2], a2[len/2]; for (int i = 0; i < len; i += 2) { a1[i/2] = y[i]; a2[i/2] = y[i+1]; } change (a1, len>>1); change (a2, len>>1); for (int i = 0; i < len/2; i++) { y[i] = a1[i]; y[i+len/2] = a2[i]; } return ; } void fft(plex y[],int len,int on) { change(y,len); for(int h = 2; h <= len; h <<= 1) { plex wn(cos(on*2*pi/h),sin(on*2*pi/h)); for(int j = 0;j < len;j+=h) { plex w(1,0); for(int k = j;k < j+h/2;k++) { plex u = y[k]; plex t = w*y[k+h/2]; y[k] = u+t; y[k+h/2] = u-t; w = w*wn; } } } if(on == -1) for(int i = 0;i < len;i++) y[i].x /= len; } long long num[maxn], sum[maxn]; int a[maxn]; plex x[maxn]; long long n; int main () { //freopen ("in.txt", "r", stdin); int t; scanf ("%d", &t); while (t--) { scanf ("%lld", &n); long long Max = 0; memset (num, 0, sizeof num); for (int i = 1; i <= n; i++) { scanf ("%d", &a[i]); num[a[i]]++; Max = max (Max, (long long)a[i]); } Max++; int len = 2; while (len < Max*2) len <<= 1; for (int i = 0; i < len; i++) { x[i] = plex (num[i], 0); } fft (x, len, 1); for (int i = 0; i < len; i++) { x[i] = x[i]*x[i]; } fft (x, len, -1); for (int i = 0; i < len; i++) { num[i] = (long long) (x[i].x+0.5); } for (int i = 1; i <= n; i++) {//两次取同一个 num[a[i]+a[i]]--; } for (int i = 0; i < len; i++) {//重复计算 num[i] /= 2; } sum[0] = 0; for (int i = 1; i < len; i++) { sum[i] = sum[i-1]+num[i]; } sort (a+1, a+1+n); long long tot = n*(n-1)*(n-2)/6, ans = tot; for (int i = 3; i <= n; i++) { ans -= sum[a[i]]; } printf ("%.7f\n", ans*1.0/tot); } return 0; }