【数据结构 扫描线】空间简单度

题意

给出一棵树,其中有一些点对是不合法的,求一共有多少条路径中是没有不合法的点对。

思路

考虑用所有路径-不合法的路径得出答案。
对于每个点对(x,y),dfn[x] 如果x为y的祖宗,那么y的子树到除了x~y这条链上其它的节点都是不合法的
如果x不为y的祖宗,那么x的子树到y的子树中的点都是不合法的
因为会有重复的不合法点对被计算,我们就把它们扔到坐标系上求面积并。

代码

#pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
#include 
#include 
#include 
#include 

int dfn[3000001], d[3000001], size[3000001], f[3000001][21];
int ver[6000001], next[6000001], head[6000001];
int n, k, tot, cnt;
long long ans;

inline long long read()
{
	long long f = 0, d = 1;
	char c; 
	while (c = getchar(), !isdigit(c)) if (c == '-') d = -1;
	f = (f << 3) + (f << 1) + c - 48;
	while (c = getchar(), isdigit(c)) f = (f << 3) + (f << 1) + c - 48;
	return d * f;
}

struct treenode {
	int l, r, len, sum;
}tree[12000001];

struct node {
	int l, r, h, mark;
}line[6000001];

bool operator < (node x, node y) {
	return x.h < y.h;
}

void build(int p, int l, int r) {
	tree[p].l = l;
	tree[p].r = r;
	tree[p].len = tree[p].sum = 0;
	if (l == r) return;
	int mid = l + r >> 1;
	build(p << 1, l, mid);
	build(p << 1 | 1, mid + 1, r);
}

void spread(int p) {
	int l = tree[p].l, r = tree[p].r;
	if (tree[p].sum) tree[p].len = r - l + 1;
	else tree[p].len = tree[p << 1].len + tree[p << 1 | 1].len;
}

void change(int p, int L, int R, int val) {
	int l = tree[p].l, r = tree[p].r;
	if (r < L || l > R) return;
	if (l >= L && r <= R) {
		tree[p].sum += val;
		spread(p);
		return;
	}
	change(p << 1, L, R, val);
	change(p << 1 | 1, L, R, val);
	spread(p);
}

void add(int u, int v) {
	ver[++tot] = v;
	next[tot] = head[u];
	head[u] = tot;
}

void dfs(int p) {
	size[p] = 1;
	dfn[p] = ++dfn[0];
    for (int i = head[p]; i; i = next[i]) {
    	if (d[ver[i]]) continue;
    	d[ver[i]] = d[p] + 1;
    	f[ver[i]][0] = p;
    	for (int j = 1; j <= 20; j++)
    		f[ver[i]][j] = f[f[ver[i]][j - 1]][j - 1];
    	dfs(ver[i]);
    	size[p] += size[ver[i]];
	}
}

int LCA(int x, int y) {
    for (int i = 20; i >= 0; i--)
		if (d[f[y][i]] > d[x]) y = f[y][i];
	return y;
}

void addl(int x1, int x2, int y1, int y2) {
	if (x1 > x2) std::swap(x1, x2);
	if (y1 > y2) std::swap(y1, y2);
	line[++cnt] = (node){x1, x2, y1, 1};
	line[++cnt] = (node){x1, x2, y2 + 1, -1};
}

void doit(int x, int y) {
	if (dfn[x] > dfn[y]) std::swap(x, y);
	if (dfn[y] <= dfn[x] + size[x] - 1 && dfn[y] > dfn[x]) {
		int son = LCA(x, y);
		if (dfn[son] > 1) addl(1, dfn[son] - 1, dfn[y], dfn[y] + size[y] - 1);
		if (dfn[son] + size[son] - 1 < n) addl(dfn[y], dfn[y] + size[y] - 1, dfn[son] + size[son], n);
	} else addl(dfn[x], dfn[x] + size[x] - 1, dfn[y], dfn[y] + size[y] - 1);
}

int main() {
	int size = 256 << 20;
    char *p = (char*)malloc(size) + size;
    __asm__("movl %0, %%esp\n" :: "r"(p));
	n = read();
	k = read();
	for (int i = 1, x, y; i < n; i++) {
		x = read();
		y = read();
		add(x, y);
		add(y, x);
	}
	d[1] = 1;
	dfs(1);
	for (int i = 1; i <= n; i++)
		for (int j = 1; j <= k && i + j <= n; j++)
			doit(i, i + j);
	build(1, 1, n);
	std::sort(line + 1, line + 1 + cnt);
	for (int i = 1; i < cnt; i++) {
		change(1, line[i].l, line[i].r, line[i].mark);
		ans += (long long)tree[1].len * (line[i + 1].h - line[i].h);
	}
	printf("%lld", (long long)n * (n - 1) / 2 - ans + n);
}

你可能感兴趣的:(数据结构)