CodeChef PrimeDST [Centroid Decomposition] [FFT]

/* I will wait for you */

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <vector>
#include <queue>
#include <deque>
#include <set>
#include <map>
#include <string>
// Short-hand macros for std::pair (pure textual substitution).
#define make(a,b) make_pair(a,b)
#define fi first
#define se second

using namespace std;

// Type aliases.
typedef long long ll;
typedef unsigned long long ull;
typedef map<int, int> mii;
typedef pair<int, int> pii;

// Floating-point constants: Pi builds the FFT twiddle factors;
// error is not referenced by the code shown.
static const double Pi = 3.1415926535897932;
static const double error = 1e-9;

// Capacity of every global vertex, edge and FFT buffer.
static const int maxn = 200010;
// The remaining integer constants are not referenced by the code shown.
static const int maxm = 1010;
static const int maxs = 26;
static const int inf = 0x3f3f3f3f;
static const int P = 1000000007;

/*
 * Reads the next decimal integer from stdin.
 * Non-digit characters before the number are skipped. A '-' is honoured as a
 * sign only when it is the last skipped character before the first digit
 * (any other skipped character resets the sign to +).
 * Returns 0 when end of input is reached before any digit, instead of
 * spinning forever on EOF.
 */
inline ll read()
{
	ll x = 0, f = 1;
	int ch = getchar();	// int, not char, so EOF stays distinguishable from data
	while (ch != EOF && (ch < '0' || ch > '9')) {
		f = (ch == '-' ? -1 : 1);
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return f * x;
}

// Minimal complex number for the FFT. It is named `complex`, so <complex>
// must not be pulled in alongside `using namespace std;` (the name would clash).
struct complex
{
	double re;	// real part
	double im;	// imaginary part
};

// _x: FFT work buffer used by _solve; w[0] / w[1]: twiddle tables filled by FFT.
complex _x[maxn], w[2][maxn];

// Component-wise sum of two complex numbers.
complex operator + (complex a, complex b)
{
	complex res = { a.re + b.re, a.im + b.im };
	return res;
}

// Component-wise difference a - b of two complex numbers.
complex operator - (complex a, complex b)
{
	complex res = { a.re - b.re, a.im - b.im };
	return res;
}

// Complex product: (a.re + i*a.im) * (b.re + i*b.im).
complex operator * (complex a, complex b)
{
	complex res = {
		a.re * b.re - a.im * b.im,	// real part
		a.re * b.im + a.im * b.re	// imaginary part
	};
	return res;
}

// One adjacency-list entry: v is the target vertex, next is the index in e[]
// of the following edge out of the same source (-1 terminates the list).
struct edge
{
	int v;
	int next;
};
edge e[maxn];

int n;			// number of vertices (read in main)
int root;		// best centroid candidate found by _find
int sum;		// size of the component _find is searching
int _maxdeep;	// largest depth over all child subtrees of the current centroid
int maxdeep;	// largest depth seen by _deep in the current child subtree
int cnt;		// number of stored edges, i.e. next free slot in e[]
int head[maxn];	// first edge index per vertex (-1 if none)
int pri[maxn];	// nonzero => index is not prime (filled by init)
int _max[maxn];	// largest piece left if this vertex were removed (set by _find)
int size[maxn];	// subtree sizes; under C++17 with `using namespace std;` refer to it as ::size
int g[maxn];	// depth histogram of one child subtree
int f[maxn];	// depth histogram of the whole component of the current centroid
int deep[maxn];	// depth of a vertex below the current centroid
int del[maxn];	// nonzero once a vertex has been used as a centroid
int rev[maxn];	// bit-reversal permutation rebuilt by FFT
ll ans;			// running count of ordered vertex pairs at prime distance

// Prepends the directed edge u -> v to u's adjacency list.
void insert(int u, int v)
{
	// Member-wise assignment keeps this ISO C++; `(edge){...}` compound
	// literals are only a GNU extension in C++.
	e[cnt].v = v;
	e[cnt].next = head[u];
	head[u] = cnt++;
}

// DFS from u (coming from p) over vertices not yet removed as centroids.
// Adds one to g[depth] for every visited vertex and raises maxdeep to the
// deepest depth seen. The caller must set deep[u] before the call.
void _deep(int u, int p)
{
	++g[deep[u]];
	if (deep[u] > maxdeep)
		maxdeep = deep[u];

	for (int id = head[u]; id != -1; id = e[id].next) {
		int to = e[id].v;
		if (to == p || del[to])
			continue;
		deep[to] = deep[u] + 1;
		_deep(to, u);
	}
}

void _size(int u, int p)
{
	size[u] = 1;
	for (int i = head[u]; ~i; i = e[i].next) {
		int v = e[i].v;
		if (v != p && !del[v])
			_size(v, u), size[u] += size[v];
	}
}

void _find(int u, int p)
{
	size[u] = 1, _max[u] = 0;
	for (int i = head[u]; ~i; i = e[i].next) {
		int v = e[i].v;
		if (v != p && !del[v]) {
			_find(v, u), size[u] += size[v];
			_max[u] = max(_max[u], size[v]);
		}
	}
	_max[u] = max(_max[u], sum - size[u]);
	if (_max[u] < _max[root]) root = u;
}

/*
 * In-place iterative radix-2 FFT of a[0..n-1]; n is expected to be a power of
 * two (the caller _solve passes one).
 *   f == 0: forward pass using w[0] (twiddles e^{+2*pi*i*k/n});
 *   f == 1: pass using the conjugate twiddles w[1], then divide the real
 *           parts (only) by n.
 * Rebuilds the global rev[] and both twiddle tables on every call.
 */
void FFT(complex *a, int n, int f)
{
	// Bit-reversal reordering: rev[idx] is idx with its low log2(n) bits reversed.
	for (int idx = 0; idx < n; ++idx) {
		int reversed = 0;
		int bits = idx;
		for (int step = 1; step < n; step <<= 1) {
			reversed = (reversed << 1) | (bits & 1);
			bits >>= 1;
		}
		rev[idx] = reversed;
		if (idx < reversed)
			swap(a[idx], a[reversed]);
	}

	// Twiddle tables: w[0][k] = e^{+2*pi*i*k/n}, w[1][k] = its conjugate.
	for (int k = 0; k < n; ++k) {
		double angle = 2 * Pi * k / n;
		w[0][k].re = cos(angle);
		w[0][k].im = sin(angle);
		w[1][k].re = cos(angle);
		w[1][k].im = -sin(angle);
	}

	// Butterflies: merge pairs of length-`half` transforms into length 2*half.
	for (int half = 1; half < n; half <<= 1) {
		int stride = n / (half << 1);	// twiddle index step at this level
		for (int start = 0; start < n; start += half << 1) {
			for (int k = 0; k < half; ++k) {
				complex lo = a[start + k];
				complex hi = w[f][k * stride] * a[start + k + half];
				a[start + k] = lo + hi;
				a[start + k + half] = lo - hi;
			}
		}
	}

	// Inverse scaling: only the real parts are divided by n.
	if (f)
		for (int k = 0; k < n; ++k)
			a[k].re /= n;
}

void _solve(int *a, int n, int f)
{
	int len = 1; 
	while (1 << len < n << 1) len += 1;

	for (int i = 0; i < 1 << len; i++)
		_x[i].re = _x[i].im = 0;
	for (int i = 0; i < n; i++)
		_x[i].re = a[i];

	FFT(_x, 1 << len, 0);
	for (int i = 0; i < 1 << len; i++)
		_x[i] = _x[i] * _x[i];
	FFT(_x, 1 << len, 1);

	for (int i = 0; i < 1 << len; i++) {
		if (!pri[i] && f == 1)
			ans += (ll) (_x[i].re + 0.5);
		if (!pri[i] && f == -1)
			ans -= (ll) (_x[i].re + 0.5);
	}
}

void solve(int u)
{
	del[u] = 1, _size(u, 0);
	f[0] = 1, _maxdeep = 0;
	for (int i = 1; i <= size[u]; i++)
		f[i] = 0;

	for (int i = head[u]; ~i; i = e[i].next) {
		int v = e[i].v;
		if (!del[v]) {
			maxdeep = 0;
			for (int i = 0; i <= size[v]; i++)
				g[i] = 0;

			deep[v] = 1, _deep(v, u);
			_solve(g, maxdeep + 1, -1);

			_maxdeep = max(_maxdeep, maxdeep);
			for (int i = 0; i <= size[v]; i++)
				f[i] += g[i];
		}
	}
	_solve(f, _maxdeep + 1, 1);
	
	for (int i = head[u]; ~i; i = e[i].next) {
		int v = e[i].v;
		if (!del[v]) {
			sum = size[v], root = 0;
			_find(v, u), solve(root);
		}
	}

}

/*
 * Sieve of Eratosthenes over [0, n) using the global n: afterwards, for
 * 0 <= x < n, pri[x] == 0 exactly when x is prime. pri[0] and pri[1] are
 * always marked.
 */
void init()
{
	pri[0] = 1;
	pri[1] = 1;
	for (int p = 2; p < n; ++p) {
		if (pri[p])
			continue;
		for (int mult = p << 1; mult < n; mult += p)
			pri[mult] = 1;
	}
}

/*
 * Reads n and the n-1 tree edges, counts ordered vertex pairs at prime distance
 * via centroid decomposition, and prints that count divided by n*(n-1),
 * i.e. the probability that two distinct random vertices are a prime distance apart.
 */
int main()
{
	n = read(), init();

	// With fewer than two vertices there is no pair to pick; the formula below
	// would evaluate 0/0 and print nan, so report 0 instead.
	if (n < 2) {
		printf("%.6f\n", 0.0);
		return 0;
	}

	memset(head, -1, sizeof head);	// -1 marks an empty adjacency list
	for (int i = 1; i < n; i++) {
		int u = read(), v = read();
		insert(u, v), insert(v, u);
	}

	// Vertex 0 acts as a sentinel root candidate that every real vertex beats.
	sum = _max[0] = n, root = 0;
	_find(1, 0), solve(root);

	printf("%.6f\n", 1.0 * ans / n / (n - 1));

	return 0;
}

You may also be interested in: (CodeChef PrimeDST [Centroid Decomposition] [FFT])