烂大街的题了,练一下手而已,但是也暴露出很多的问题,然后我一直把时间优化到1s左右才罢手
http://www.codechef.com/problems/PRIMEDST
复杂度为 $O(n\log^2 n)$:点分治共 $O(\log n)$ 层,每层合并统计的时候用了 FFT,单层是 $O(n\log n)$
顺带提一句,编译器的差别好大。。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>

// CodeChef PRIMEDST (http://www.codechef.com/problems/PRIMEDST).
//
// Reads trees from stdin until input ends. For each tree it counts the
// unordered vertex pairs whose distance is a prime number and prints
// count / C(n, 2) with 8 decimals.
//
// Approach: centroid decomposition. At every centroid the depths of all
// vertices of the current component are collected into a histogram, the
// histogram is convolved with itself via FFT, and pairs that lie in the same
// child subtree are subtracted again (inclusion-exclusion).

using std::vector;

typedef long long type;

// Minimal complex number used by the FFT (real part x, imaginary part y).
struct comp {
    double x, y;
    comp(double _x = 0, double _y = 0) : x(_x), y(_y) {}
};

namespace FFT {

const int N = 131072, MinSize = 400000;
const double pi2 = 3.1415926535897932 * 2;

comp a[N], b[N], tmp[N];
int n, bn;    // current transform size (a power of two) and log2 of it
type res[N];  // result of the last mul() call

// Primitive m-th root of unity; conjugated when inv is true.
inline comp W(int m, bool inv) {
    double ang = inv ? -pi2 / m : pi2 / m;
    return comp(cos(ang), sin(ang));
}

// Reverses the lowest bn bits of x.
inline int bitrev(int x) {
    int ans = 0;
    for (int i = 1; i <= bn; ++i) {
        ans <<= 1;
        ans |= x & 1;
        x >>= 1;
    }
    return ans;
}

// In-place iterative radix-2 transform of data[0..n). The inverse transform
// is NOT scaled here; the caller divides by n.
void dft(comp *data, bool inv) {
    // Bit-reversal permutation.
    for (int i = 0; i < n; ++i) {
        int to = bitrev(i);
        if (to > i) std::swap(data[to], data[i]);
    }

    for (int i = 1; i <= bn; ++i) {
        const comp wi = W(1 << i, inv);
        const int step = 1 << (i - 1);
        comp w(1, 0);
        for (int k = 0; k < step; ++k) {
            // Apply the same twiddle factor w to every block of size 2^i.
            for (int j = 0; j < n; j += 1 << i) {
                const int t = j | k, d = j | k | step;
                const comp A = data[t];
                comp B;
                B.x = w.x * data[d].x - w.y * data[d].y;
                B.y = w.x * data[d].y + w.y * data[d].x;
                data[t].x = A.x + B.x, data[t].y = A.y + B.y;
                data[d].x = A.x - B.x, data[d].y = A.y - B.y;
            }
            // w *= wi
            comp next;
            next.x = w.x * wi.x - w.y * wi.y;
            next.y = w.x * wi.y + w.y * wi.x;
            w = next;
        }
    }
}

// Convolves x1[0..n1) with ITSELF and stores the coefficients in res[].
// n2 only influences the transform size and x2 is never read: the single
// caller always passes the same array twice, so only the square is computed.
// Returns (index of the highest non-zero coefficient) + 1, or 1 if all
// coefficients are zero.
// Precondition: 2 * nextPow2(max(n1, n2)) <= N.
int mul(int n1, int *x1, int n2, int *x2) {
    (void)x2;
    n = std::max(n1, n2);
    for (bn = 0; (1 << bn) < n; ++bn) {}
    ++bn;  // double the size so the full product fits
    n = 1 << bn;

    for (int i = 0; i < n1; ++i) a[i] = comp(x1[i], 0);
    for (int i = n1; i < n; i++) a[i] = comp(0, 0);

    dft(a, false);
    for (int i = 0; i < n; ++i) {
        // Pointwise square.
        tmp[i].x = a[i].x * a[i].x - a[i].y * a[i].y;
        tmp[i].y = a[i].x * a[i].y + a[i].y * a[i].x;
    }
    dft(tmp, true);

    // Scale by 1/n and convert to integers: a small bias is added before
    // truncation to absorb floating-point error.
    for (int i = 0; i < n; ++i) res[i] = (type)(tmp[i].x / n + 0.1);

    for (--n; n && !res[n]; --n) {}
    return n + 1;
}

}  // namespace FFT

const int N = 50010;

bool vis[N];   // vis[i] == true  <=>  i is composite (for i >= 2)
int p[N], pn;  // primes below N in increasing order, and their count

// Sieve of Eratosthenes filling vis[] and p[].
void init() {
    pn = 0;
    for (int i = 2; i < N; i++) {
        if (vis[i]) continue;  // multiples of a composite are already marked
        for (int j = i + i; j < N; j += i) vis[j] = true;
    }
    for (int i = 2; i < N; i++)
        if (!vis[i]) p[pn++] = i;
}

// Adjacency list stored as linked lists over edge indices.
int head[N], nxt[N * 2], pnt[N * 2];
int E, n;

// Adds the directed edge a -> b.
void add(int a, int b) {
    pnt[E] = b;
    nxt[E] = head[a];
    head[a] = E++;
}

bool del[N];         // vertex already used as a centroid
int son[N], opt[N];  // subtree size / largest child subtree size
vector<int> alln;    // vertices of the component visited by the last dfs()

// Computes son[] and opt[] for the not-yet-deleted component containing u,
// rooted at u (f is the parent, -1 for the root), and records every visited
// vertex in alln.
void dfs(int u, int f) {
    alln.push_back(u);
    son[u] = 1;
    opt[u] = 0;
    for (int i = head[u]; i != -1; i = nxt[i]) {
        const int v = pnt[i];
        if (v == f || del[v]) continue;
        dfs(v, u);
        son[u] += son[v];
        opt[u] = std::max(opt[u], son[v]);
    }
}

// Returns the vertex of u's component whose removal leaves the smallest
// largest remaining part; the first such vertex in dfs order wins ties.
int getcenter(int u) {
    alln.clear();
    dfs(u, -1);
    int mx = 0, ans = -1;
    const int sz = alln.size();
    for (int i = 0; i < sz; i++) {
        const int v = alln[i];
        const int worst = std::max(opt[v], sz - son[v]);
        if (ans == -1 || worst < mx) {
            mx = worst;
            ans = v;
        }
    }
    return ans;
}

int tot;
int D[N];  // D[0..tot): depths gathered by getdist()

// Appends to D[] the depth of every not-yet-deleted vertex reachable from u
// without going back through f; u itself has depth prew.
void getdist(int u, int f, int prew) {
    D[tot++] = prew;
    for (int i = head[u]; i != -1; i = nxt[i]) {
        const int v = pnt[i];
        if (v == f || del[v]) continue;
        getdist(v, u, prew + 1);
    }
}

int cnt[50010];

// Counts the unordered pairs {i, j}, i != j, with D[i] + D[j] prime.
// Requires tot >= 1.
inline long long calc() {
    int len = *std::max_element(D, D + tot) + 1;
    std::fill(cnt, cnt + len, 0);
    for (int i = 0; i < tot; i++) cnt[D[i]]++;

    // After squaring the histogram, res[s] is the number of ordered pairs
    // (i, j) with D[i] + D[j] == s, including i == j.
    len = FFT::mul(len, cnt, len, cnt);

    // Drop the i == j pairs ...
    for (int i = 0; i < tot; i++) FFT::res[D[i] + D[i]]--;

    // ... then halve (ordered -> unordered) and sum over prime totals.
    long long ans = 0;
    for (int i = 0; i < pn && p[i] < len; i++) {
        FFT::res[p[i]] /= 2;
        ans += FFT::res[p[i]];
    }
    return ans;
}

long long ans;

// Adds to the global ans the number of prime-distance pairs inside the
// not-yet-deleted component containing u.
void solve(int u) {
    u = getcenter(u);

    // Every pair of depths measured from the centroid ...
    tot = 0;
    getdist(u, -1, 0);
    ans += calc();

    // ... minus the pairs that lie inside one child subtree, whose real path
    // does not pass through the centroid.
    for (int i = head[u]; i != -1; i = nxt[i]) {
        if (del[pnt[i]]) continue;
        tot = 0;
        getdist(pnt[i], u, 1);
        ans -= calc();
    }

    del[u] = true;
    for (int i = head[u]; i != -1; i = nxt[i]) {
        if (del[pnt[i]]) continue;
        solve(pnt[i]);
    }
}

int main() {
    init();
    // Compare against 1 rather than EOF: scanf returns 0 on a non-numeric
    // token without consuming it, which would otherwise loop forever.
    while (scanf("%d", &n) == 1) {
        if (n < 1 || n >= N) {
            fprintf(stderr, "invalid vertex count: %d\n", n);
            return 1;
        }
        E = 0;
        std::fill(head, head + n + 1, -1);
        std::fill(del, del + n + 1, false);
        for (int i = 1; i < n; i++) {
            int a, b;
            // Endpoints index head[], so reject anything outside 1..n.
            if (scanf("%d%d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n) {
                fprintf(stderr, "invalid edge on input\n");
                return 1;
            }
            add(a, b);
            add(b, a);
        }
        ans = 0;
        solve(1);
        if (n < 2) {
            // No pairs exist; avoid printing NaN from 0/0.
            printf("%.8f\n", 0.0);
        } else {
            printf("%.8f\n", 1.0 * ans * 2 / n / (n - 1));
        }
    }
    return 0;
}