/*
转载请注明出处,谢谢 http://blog.csdn.net/ACM_cxlove?viewmode=contents by---cxlove
太逗了……比赛的时候,误以为素数会很多……
然后就想歪了,开始搞 FFT。
其实发现主要是 a + b + c 的情况不好处理。
先将 a + b 的情况 FFT 一下,然后再 + c FFT 一次。
num[i] 表示 a + b = i 的个数,sum[i] 表示 a + b + c = i 的个数。
但是去重比较麻烦,a + a = i 的情况最好单独考虑,从 num[i] 中去掉,查询的时候单独统计。
这样的话和 c 合并,可能出现 a + b + a 的情况,这种情况也需要在查询的时候单独处理一下,最后再去掉排列组合的情况。
赛后发现只有 8000 个素数,所有情况平方都可以预处理。
*/
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <stack>
#define lson step << 1
#define rson step << 1 | 1
#define lowbit(x) (x & (-x))
#define Key_value ch[ch[root][1]][0]
using namespace std;

typedef long long LL;

const int N = 80005;
const LL MOD = 1000000007;

// Largest query value covered by the precomputed tables (must stay below N).
static const int LIMIT = 80000;

// prime[0..cnt) lists the primes below N in increasing order.
// flag[i] == 0 exactly when i is prime (for 0 <= i < N).
int prime[N], flag[N], cnt = 0;

// Sieve that fills flag[], prime[] and cnt. Must be called exactly once,
// because it appends to prime[] starting at the current cnt.
void init() {
    flag[0] = flag[1] = 1;
    for (int i = 2; i < N; i++) {
        if (flag[i]) continue;
        prime[cnt++] = i;
        for (int multiple = 2 * i; multiple < N; multiple += i) {
            flag[multiple] = 1;
        }
    }
}

// FFT routines adapted from kuangbin's template.
const double pi = acos(-1.0);

// Complex number z = a + b * i.
struct Complex {
    double a, b;
    Complex(double _a = 0.0, double _b = 0.0) : a(_a), b(_b) {}
    Complex operator+(const Complex &c) const {
        return Complex(a + c.a, b + c.b);
    }
    Complex operator-(const Complex &c) const {
        return Complex(a - c.a, b - c.b);
    }
    Complex operator*(const Complex &c) const {
        return Complex(a * c.a - b * c.b, a * c.b + b * c.a);
    }
};

// Bit-reversal permutation of y[0..len); len must be a power of two.
void change(Complex y[], int len) {
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(y[i], y[j]);
        // Advance j to the bit-reversed counterpart of i + 1.
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        if (j < k) j += k;
    }
}

// In-place iterative FFT of y[0..len); len must be a power of two.
// on == 1 computes the forward DFT, on == -1 the inverse DFT.
// The inverse transform divides both the real and the imaginary part by len.
void FFT(Complex y[], int len, int on) {
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        Complex wn(cos(-on * 2 * pi / h), sin(-on * 2 * pi / h));
        for (int j = 0; j < len; j += h) {
            Complex w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                Complex u = y[k];
                Complex t = w * y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        for (int i = 0; i < len; i++) {
            y[i].a /= len;
            y[i].b /= len;
        }
    }
}

// num[i]: number of unordered pairs of distinct primes {a, b} with a + b == i.
// sum[i]: number of (pair {a, b} with a < b, prime c) with a + b + c == i.
// Both are filled for 0 <= i <= LIMIT by precompute().
LL sum[N << 2], num[N << 2];
Complex x1[N << 2], x2[N << 2];

// Writes the prime indicator polynomial into buf[0..len): coefficient i is 1
// when i <= LIMIT is prime, and 0 otherwise (including the zero padding).
static void load_prime_indicator(Complex buf[], int len) {
    for (int i = 0; i < len; i++) {
        const bool is_prime = (i <= LIMIT) && (flag[i] == 0);
        buf[i] = Complex(is_prime ? 1.0 : 0.0, 0.0);
    }
}

// Runs the sieve and builds num[] and sum[] with two FFT convolutions.
static void precompute() {
    init();

    int len = 1;
    while (len < LIMIT * 2) len <<= 1;

    // Square the indicator polynomial: coefficient i counts ordered (a, b).
    load_prime_indicator(x1, len);
    FFT(x1, len, 1);
    for (int i = 0; i < len; i++) x1[i] = x1[i] * x1[i];
    FFT(x1, len, -1);
    for (int i = 0; i <= LIMIT; i++) {
        num[i] = static_cast<LL>(x1[i].a + 0.5);
    }

    // Drop the a == b pairs (only entries up to LIMIT are ever read), then
    // halve to turn ordered pairs into unordered ones.
    for (int i = 0; 2 * i <= LIMIT; i++) {
        if (flag[i] == 0) num[2 * i]--;
    }
    for (int i = 0; i <= LIMIT; i++) num[i] /= 2;

    // Convolve num[] with the indicator polynomial to attach a third prime.
    load_prime_indicator(x1, len);
    for (int i = 0; i < len; i++) {
        x2[i] = Complex(i <= LIMIT ? static_cast<double>(num[i]) : 0.0, 0.0);
    }
    FFT(x1, len, 1);
    FFT(x2, len, 1);
    for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
    FFT(x1, len, -1);
    for (int i = 0; i <= LIMIT; i++) {
        sum[i] = static_cast<LL>(x1[i].a + 0.5);
    }
}

// Counts, modulo MOD, the expressions a, a*b, a*b+c, a*b*c, a+b and a+b+c
// over primes (operands of the same operator taken unordered) that equal n.
// Requires n <= LIMIT and a prior call to precompute().
static LL count_expressions(int n) {
    // Every prime is at least 2, so nothing below 2 can be expressed.
    if (n < 2) return 0;

    LL ans = (flag[n] == 0) ? 1 : 0;  // a

    // a * b with a <= b.
    for (int i = 0; i < cnt && 1LL * prime[i] * prime[i] <= n; i++) {
        if (n % prime[i] == 0 && flag[n / prime[i]] == 0) ans++;
    }

    // a * b + c and a * b * c share the enumeration of pairs a <= b.
    for (int i = 0; i < cnt; i++) {
        const LL a = prime[i];
        if (a * a > n) break;
        for (int j = i; j < cnt && a * prime[j] <= n; j++) {
            const int ab = static_cast<int>(a * prime[j]);
            if (flag[n - ab] == 0) ans++;  // a * b + c
            if (n % ab == 0) {
                const int c = n / ab;
                // c >= b keeps each multiset {a, b, c} counted once.
                if (c >= prime[j] && flag[c] == 0) ans++;  // a * b * c
            }
        }
    }

    // a + b: distinct pairs from num[], plus the a == b case.
    const int same_pair = (n % 2 == 0 && flag[n / 2] == 0) ? 1 : 0;
    ans = (ans + num[n] + same_pair) % MOD;

    // a + b + c: sum[n] counts a triple of distinct primes three times and a
    // triple {a, a, b} with a != b once; adding 2 per (a, a, b) makes every
    // such triple count three times. The {a, a, a} triple contributes only 2
    // here, which the integer division discards; it is added separately.
    LL doubled = 0;
    for (int i = 0; i < cnt; i++) {
        const int rest = n - prime[i] * 2;
        if (rest <= 0) break;
        if (flag[rest] == 0) doubled++;
    }
    ans = (ans + (sum[n] + 2 * doubled) / 3) % MOD;
    if (n % 3 == 0 && flag[n / 3] == 0) ans = (ans + 1) % MOD;

    return ans;
}

// Reads integers from stdin until EOF or the first token that is not an
// integer, printing one answer per line.
int main() {
    precompute();

    int n;
    // scanf returns 0 (not EOF) on a malformed token without consuming it,
    // so comparing against 1 avoids looping forever on bad input.
    while (scanf("%d", &n) == 1) {
        if (n > LIMIT) {
            // The lookup tables do not extend past LIMIT; indexing them with
            // a larger n would read out of bounds.
            fprintf(stderr, "n = %d exceeds the supported maximum %d\n", n, LIMIT);
            continue;
        }
        printf("%lld\n", count_expressions(n));
    }
    return 0;
}