比赛没有仔细想,码了个 n2 暴力,结果还被卡常了。
暴力思路如下:
先枚举根为x,依次加入x+1,x+2,x+3……,加入一个点y(y>x)时,如果它的子树中没有已经加入的点,那么它就会使距离增大,再暴力往上跳,给它的祖先打上标记,直到到某个点,这个点已经被打上标记就可以停止了,同时可以算出增加的距离。
感觉这个思路很想GDOI-2017-D1-T2那种,因为每个点只会被打上一次标记,所以复杂度可以保证。
然而我忽略了这种树上和边有关的计数类问题必须要想到一种思路:对每条边的贡献单独考虑。
我们可以枚举一条边,给这条边分成的两个集合的点黑白染色,再反映到1-n的序列上。
这条边的贡献等于 n∗(n+1)/2−∑s∗(s+1)/2 ,其中s是各个连续相同颜色块的长度。
当然我们也可以随便选择一个根建树,对于每条边,只给它的深度大的端点的子树染色,利用线段树就可以直接求答案了。
既然是给子树染色,那么树上启发式用可以用上了。
老套路:先搞轻儿子,清空线段树,再搞重儿子,保留线段树,再暴力把轻儿子的子树和自己加进线段树,最后判断是否子树在线段树里的东西。
撤销操作就暴力染回去。
时间复杂度: O(n log n log n)
Code:
#include
#include
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define sqr(a) ((a) * (a + 1) / 2)
using namespace std;
const int N = 100005;
const ll mo = 1e9 + 7;
int n, x, y;
int final[N], tot;
struct edge {
int to, next;
}e[N * 2];
int fa[N], bz[N], siz[N], son[N];
ll ans;
ll ksm(ll x, ll y) {
ll s = 1;
for(; y; x = x * x % mo, y >>= 1)
if(y & 1) s = s * x % mo;
return s;
}
struct tree {
int l, r; ll sl, sr, s;
}t[N * 10];
void link(int x, int y) {
e[++ tot].next = final[x], e[tot].to = y, final[x] = tot;
e[++ tot].next = final[y], e[tot].to = x, final[y] = tot;
}
void dg(int x) {
bz[x] = 1;
siz[x] = 1;
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y]) continue;
fa[y] = x; dg(y); siz[x] += siz[y];
son[x] = siz[y] > siz[son[x]] ? y : son[x];
}
bz[x] = 0;
}
void Bin(int i, int l, int r) {
int x = i + i, y = i + i + 1;
int m = (l + r) / 2;
t[i].l = t[x].l; t[i].r = t[y].r;
t[i].s = t[x].s + t[y].s;
if(t[x].r == t[y].l) t[i].s -= sqr(t[x].sr) + sqr(t[y].sl) - sqr(t[x].sr + t[y].sl);
if(t[x].r == t[y].l && t[x].sl == (m - l + 1))
t[i].sl = t[x].sl + t[y].sl; else t[i].sl = t[x].sl;
if(t[x].r == t[y].l && t[y].sl == (r - m))
t[i].sr = t[y].sr + t[x].sr; else t[i].sr = t[y].sr;
}
void Build(int i, int x, int y) {
if(x == y) {
t[i].s = 1;
t[i].l = t[i].r = 0;
t[i].sl = t[i].sr = 1;
return;
}
int m = (x + y) / 2;
Build(i + i, x, m); Build(i + i + 1, m + 1, y);
Bin(i, x, y);
}
void change(int i, int x, int y, int l, int r) {
if(x == y) {
t[i].s = 1;
t[i].l = t[i].r = r;
t[i].sl = t[i].sr = 1;
return;
}
int m = (x + y) / 2;
if(l <= m) change(i + i, x, m, l, r); else change(i + i + 1, m + 1, y, l, r);
Bin(i, x, y);
}
void dfs(int x) {
bz[x] = 1;
change(1, 1, n, x, 1);
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y]) continue;
dfs(y);
}
bz[x] = 0;
}
void Clear(int x) {
change(1, 1, n, x, 0);
bz[x] = 1;
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y]) continue;
Clear(y);
}
bz[x] = 0;
}
void dg1(int x) {
bz[x] = 1;
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y] || y == son[x]) continue;
dg1(y);
}
if(son[x]) dg1(son[x]);
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y] || y == son[x]) continue;
dfs(y);
}
change(1, 1, n, x, 1);
ans += (ll)n * (n + 1) / 2 - t[1].s;
if(son[fa[x]] != x) Clear(x);
bz[x] = 0;
}
int main() {
freopen("communicate.in", "r", stdin);
freopen("communicate.out", "w", stdout);
scanf("%d", &n);
fo(i, 1, n - 1) {
scanf("%d %d", &x, &y);
link(x, y);
}
dg(1);
Build(1, 1, n);
dg1(1);
ans = ans * 2 % mo * ksm((ll)n * (n + 1) / 2, mo - 2) % mo;
printf("%lld", ans);
}