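/*
给出一棵树,其中有一些点对是不合法的,求一共有多少条路径上不包含任何不合法的点对。
考虑用"所有路径数 - 包含不合法点对的路径数"得出答案。
对于每个点对 (x, y),不妨设 dfn[x] < dfn[y]。如果 x 是 y 的祖先,记 son 为 x 在通往 y 的路径上的那个儿子,那么 y 子树中的点与 son 子树以外的所有点之间的路径都是不合法的;
如果 x 不是 y 的祖先,那么 x 子树中的点与 y 子树中的点之间的路径都是不合法的。
把路径的两个端点 (u, v) 看成坐标系中的点 (dfn[u], dfn[v]),每个点对对应若干个矩形;因为同一条不合法路径可能被多个点对重复计算,所以用扫描线求这些矩形的面积并。
*/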
#pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>
// Per-vertex tree data, indexed by vertex id.
int dfn[3000001];       // preorder index of each vertex; dfn[0] is the running counter
int d[3000001];         // depth (root has depth 1); 0 means "not visited yet"
int size[3000001];      // subtree size
int f[3000001][21];     // binary lifting: f[v][j] is the 2^j-th ancestor of v (0 above the root)

// Adjacency lists stored as singly linked lists of edge slots (two slots per undirected edge).
int ver[6000001];       // target vertex of each edge slot
int next[6000001];      // next edge slot with the same source vertex; 0 ends the list
int head[6000001];      // first edge slot of each vertex; 0 if it has none

int n;                  // number of vertices
int k;                  // pairs (i, i + j) with 1 <= j <= k are the forbidden pairs
int tot;                // number of edge slots used so far
int cnt;                // number of scanline events stored in line[]
long long ans;          // area of the union of the forbidden-path rectangles
// Read one decimal integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in that skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit is seen.
inline long long read()
{
    long long f = 0, d = 1;
    int c;  // int, not char: keeps EOF distinct and keeps isdigit()'s argument in range
    while ((c = getchar()) != EOF && !isdigit(c))
        if (c == '-') d = -1;
    if (c == EOF) return 0;  // without this check the loop above spun forever at end of input
    f = c - '0';
    while (isdigit(c = getchar())) f = f * 10 + (c - '0');  // isdigit(EOF) is false
    return d * f;
}
// Segment-tree node used by the scanline over dfn positions.
struct treenode {
    int l;      // left end of the position range this node spans
    int r;      // right end (inclusive)
    int len;    // positions in [l, r] covered by at least one active interval
    int sum;    // active intervals that cover this node's whole range
};
treenode tree[12000001];
// One scanline event: the interval [l, r] becomes active (mark = +1) or inactive (mark = -1)
// at height h.
struct node {
    int l;
    int r;
    int h;
    int mark;
};
node line[6000001];

// Events are swept in increasing height.
bool operator < (node x, node y) {
    return y.h > x.h;
}
void build(int p, int l, int r) {
tree[p].l = l;
tree[p].r = r;
tree[p].len = tree[p].sum = 0;
if (l == r) return;
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
}
// Recompute tree[p].len after tree[p].sum or one of its children changed. len is the number of
// positions in [l, r] covered by at least one active interval:
//   - if an interval covers this whole node (sum > 0), all r - l + 1 positions are covered;
//   - otherwise a leaf covers nothing, and an inner node takes its children's total.
// The explicit leaf case matters: build() never creates children under a leaf. For large n,
// 2p / 2p + 1 of a deep leaf can index past the end of tree[].
void spread(int p) {
    const int l = tree[p].l, r = tree[p].r;
    if (tree[p].sum) tree[p].len = r - l + 1;
    else if (l == r) tree[p].len = 0;
    else tree[p].len = tree[p << 1].len + tree[p << 1 | 1].len;
}
// Add val to the cover count of every position in [L, R] below node p, then refresh covered
// lengths on the way back up. Nodes lying wholly inside [L, R] absorb the update in their own sum.
void change(int p, int L, int R, int val) {
    const int lo = tree[p].l;
    const int hi = tree[p].r;
    const bool disjoint = hi < L || lo > R;
    if (disjoint) return;
    const bool inside = L <= lo && hi <= R;
    if (inside) {
        tree[p].sum += val;
    } else {
        change(2 * p, L, R, val);
        change(2 * p + 1, L, R, val);
    }
    spread(p);
}
// Push a directed edge u -> v onto the front of u's adjacency list.
void add(int u, int v) {
    const int slot = ++tot;
    ver[slot] = v;
    next[slot] = head[u];
    head[u] = slot;
}
// Preorder DFS from p. The caller must already have set d[p] to a non-zero value (e.g. the root
// with d = 1). For every vertex v reached it fills:
//   dfn[v]  - preorder index, taken from the counter dfn[0]
//   size[v] - number of vertices in v's subtree
//   d[v]    - depth (parent's depth + 1); d[v] != 0 doubles as the "visited" flag
//   f[v][j] - 2^j-th ancestor for j = 0..20 (0 once past the root)
// Children are visited in adjacency-list order, the same order the recursive form used.
// An explicit stack replaces recursion so that a path-shaped tree with millions of vertices
// cannot overflow the call stack.
void dfs(int p) {
    // Each frame is (vertex, next edge slot of that vertex still to scan).
    std::vector<std::pair<int, int> > stk;
    size[p] = 1;
    dfn[p] = ++dfn[0];
    stk.push_back(std::make_pair(p, head[p]));
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int e = stk.back().second;
        if (e == 0) {
            // All neighbours of u handled: fold its subtree size into its parent's.
            stk.pop_back();
            if (!stk.empty()) size[stk.back().first] += size[u];
            continue;
        }
        stk.back().second = next[e];  // advance u's cursor now; the push below may reallocate
        const int v = ver[e];
        if (d[v]) continue;  // already visited (the parent of u)
        d[v] = d[u] + 1;
        f[v][0] = u;
        for (int j = 1; j <= 20; j++)
            f[v][j] = f[f[v][j - 1]][j - 1];
        size[v] = 1;
        dfn[v] = ++dfn[0];
        stk.push_back(std::make_pair(v, head[v]));
    }
}
// Despite its name this is not a general LCA. It lifts y, via binary jumps, to its highest
// ancestor whose depth still exceeds d[x]. When y is a proper descendant of x (the only way
// doit() calls it), that is the child of x on the path down to y. If d[y] <= d[x], y is
// returned unchanged.
int LCA(int x, int y) {
    const int limit = d[x];
    int cur = y;
    for (int step = 20; step >= 0; --step) {
        const int up = f[cur][step];
        if (d[up] > limit) cur = up;
    }
    return cur;
}
void addl(int x1, int x2, int y1, int y2) {
if (x1 > x2) std::swap(x1, x2);
if (y1 > y2) std::swap(y1, y2);
line[++cnt] = (node){x1, x2, y1, 1};
line[++cnt] = (node){x1, x2, y2 + 1, -1};
}
void doit(int x, int y) {
if (dfn[x] > dfn[y]) std::swap(x, y);
if (dfn[y] <= dfn[x] + size[x] - 1 && dfn[y] > dfn[x]) {
int son = LCA(x, y);
if (dfn[son] > 1) addl(1, dfn[son] - 1, dfn[y], dfn[y] + size[y] - 1);
if (dfn[son] + size[son] - 1 < n) addl(dfn[y], dfn[y] + size[y] - 1, dfn[son] + size[son], n);
} else addl(dfn[x], dfn[x] + size[x] - 1, dfn[y], dfn[y] + size[y] - 1);
}
// Reads n, k and the n - 1 tree edges. Every pair (i, i + j) with 1 <= j <= k is forbidden.
// Collects the paths containing a forbidden pair as rectangles over (dfn, dfn), and prints the
// number of paths minus the area of the union of those rectangles.
int main() {
#if defined(__i386__)
    // 32-bit x86 only: point the stack at the top of a 256 MiB heap block to give deep recursion
    // room on path-like trees. `movl` into %esp cannot take the 64-bit pointer register GCC picks
    // on x86-64 (the asm does not assemble there), so other targets keep the default stack.
    int stack_bytes = 256 << 20;
    char *stack_base = (char*)malloc(stack_bytes);
    if (stack_base) {
        char *stack_top = stack_base + stack_bytes;
        __asm__("movl %0, %%esp\n" :: "r"(stack_top));
    }
#endif
    n = read();
    k = read();
    for (int i = 1, x, y; i < n; i++) {
        x = read();
        y = read();
        add(x, y);
        add(y, x);
    }
    d[1] = 1;  // root the tree at vertex 1; dfs() treats d[v] != 0 as "visited"
    dfs(1);
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= k && i + j <= n; j++)
            doit(i, i + j);
    build(1, 1, n);
    std::sort(line + 1, line + 1 + cnt);
    // Sweep events by height; tree[1].len is the covered width between consecutive heights.
    for (int i = 1; i < cnt; i++) {
        change(1, line[i].l, line[i].r, line[i].mark);
        ans += (long long)tree[1].len * (line[i + 1].h - line[i].h);
    }
    // n*(n-1)/2 paths between distinct endpoints, minus the bad ones, plus n single-vertex paths.
    printf("%lld", (long long)n * (n - 1) / 2 - ans + n);
}