算法 | 虚树学习笔记

虚树学习笔记

blog

  • 虚树是一棵虚拟构建的树…废话 这棵树只包含关键点和关键的点,而其他不影响虚树结构的点和边都相当于进行了路径压缩…而且整棵虚树的大小不会超过关键点的2倍

    • 举个例子?
    • 比方说…4 5 6 7 表示关键点(现在你要让这些关键点也形成一棵树…要求点最少但不能破坏原来树的结构)算法 | 虚树学习笔记_第1张图片
    • 构建完的虚树长这个样(其实这也是构建虚树时候的最坏情况…所有原来树上的点都加入了虚树之中)算法 | 虚树学习笔记_第2张图片
    • 什么时候才会出现这种最坏的情况???当所有叶子节点都是关键点的时候…,但是很显然的是叶子结点不会超过 n 2 \dfrac{n}{2} 2n
  • 如何构建虚树?

    • 预处理出dfs序…将关键点按照dfs序进行排序

    • 我们使用一个,从栈顶到栈底的元素形成虚树的一条树链.

    • Code

      • 	sort(af + 1, af + 1 + num, cmp);
        	if(con[1] != 1) st[++top] = 1;
        	for(int i = 1; i <= num; ++i) {
        		int pos = af[i], fa = 0;
        		for(; top;) {
        			fa = GetLCA(st[top], pos);
        			if(top > 1 && dep[fa] < dep[st[top - 1]]) {
        				add(st[top - 1], st[top]), top--;
        			} else if(dep[fa] < dep[st[top]]) {
        				add(fa, st[top]), top--;
        				break;
        			} else {
        				break;
        			}
        		}
        		if(st[top] != fa) 
        			st[++top] = fa;
        		st[++top] = pos;
        	}
        	for(; top > 1; --top)
        		add(st[top - 1], st[top]);
        
  • 正确性✅?

    • 对于任意指定两点 a , b a,b a,b l c a lca lca ,都存在 d f s dfs dfs 序连续的两点 u , v u,v u,v( d f n [ u ] ≤ d f n [ v ] dfn[u]\leq dfn[v] dfn[u]dfn[v]) 分别属于 l c a lca lca包含 a , b a,b a,b 的两棵子树,此时这 v v v 加入时按照上面的操作必定会把 l c a lca lca加入栈,所以应当加入的点都加入了。对于非 l c a lca lca 点,按照上面操作是不会出现这个点的。

HDU6035 Colorful tree

  • 题意:

    • 给你一棵 n ( 1 ≤ n ≤ 200000 ) n(1\leq n\leq 200000) n(1n200000)个节点的树…每一个节点有一个颜色.一条路径的权值定义为出现颜色集合的大小…问所有路径的权值和是多少…
  • 题解:

    当我们在考虑这个问题怎么求的时候…我们很容易地考虑到要计算每个颜色的贡献

    我们会发现这样正着做会很难…遇难则反

    • 问题转换为求有多少条路径没有出现过颜色 c c c

    事实上我们并不需要构建出每一种颜色的虚树… 只需一遍dfs就可以考虑啦!

    但是应用到了虚树的思想…

    首先一个颜色没有 c o l col col的联通块里任意两个点的路径上都不会出现 c o l col col这个颜色

    我们记录一个 s u m [ x ] sum[x] sum[x]表示遍历到当前这个点 颜色为 x x x的最高子树总大小。

    • 算法 | 虚树学习笔记_第3张图片

    假设当前点为 n o w now now,那么接下去就要遍历它的儿子…

    • 我们记下遍历每个儿子之前的 s u m [ c o l [ n o w ] ] sum[col[now]] sum[col[now]]值为 l a s t last last
    • 那么遍历完当前儿子之后 a d d = s u m [ c o l [ n o w ] ] − l a s t add=sum[col[now]] - last add=sum[col[now]]last就是这个儿子顶部颜色为 c o l [ n o w ] col[now] col[now]的个数
    • 那么 s z [ s o n ] − a d d sz[son]-add sz[son]add就是颜色不为 c o l [ n o w ] col[now] col[now]的联通块大小
  • Code

    #include 
    #define ll long long
    using namespace std;
    const int N = 2e5 + 10;
    struct data {
    	int nt, to;
    } a[N << 1];
    ll ans;
    int sz[N], vis[N], head[N], col[N], sum[N], cnt = 0;
    
    void add(int x, int y) {
    	a[++cnt].to = y;
    	a[cnt].nt = head[x];
    	head[x] = cnt;
    }
    
    ll js(ll x) {
    	return x * (x - 1) / 2;
    }
    
    void dfs(int u, int fa) {
    	sz[u] = 1;
    	int son = 0;
    	for(int i = head[u]; i; i = a[i].nt) {
    		int to = a[i].to;
    		if(to == fa) {
    			continue;
    		}
    		int last = sum[col[u]];
    		dfs(to, u);
    		sz[u] += sz[to];
    		int now = sum[col[u]] - last;
    		ans -= js(sz[to] - now);
    		son += sz[to] - now;
    	}
    	sum[col[u]] += son + 1;
    }
    
    int main() {
    	int n, cas = 0;
    	while(scanf("%d", &n) == 1) {
    		cnt = ans = 0;
    		int tot = 0;
    		for(int i = 1; i <= n; ++i) {
    			scanf("%d", &col[i]);
    			if(!vis[col[i]]) {
    				tot++, vis[col[i]] = 1;
    			}
    		}
    		for(int i = 1, x, y; i < n; ++i) {
    			scanf("%d%d", &x, &y);
    			add(x, y), add(y, x);
    		}
    		ans = 1ll * tot * js(n);
    		dfs(1, 0);
    		for(int i = 1; i <= n; ++i) {
    			head[i] = 0;
    			if(vis[col[i]]) {
    				// printf("%d %d\n", col[i], sum[col[i]]);
    				ans -= js(n - sum[col[i]]);
    				vis[col[i]] = 0;
    				sum[col[i]] = 0;
    			}
    		}
    		++cas;
    		printf("Case #%d: %lld\n", cas, ans);
    	}
    	return 0;
    }
    

HNOI2014 | 世界树

  • 题意

    • 给你一棵 n ( 1 ≤ n ≤ 100000 ) n(1\leq n\leq100000) n(1n100000)个节点的树 每次给你 n u m ( ∑ n u m ≤ 100000 ) num(\sum num \leq 100000) num(num100000)个控制点

      • 一个点的控制点是离它最近的且标号最小的点

      问你每个控制点控制了多少个点

  • 题解

    首先肯定是对这个 n u m num num个点构建虚树

    先预处理出虚数上的点的控制点 在去更新原树上的控制点

    算法 | 虚树学习笔记_第4张图片

    (蓝色点为虚树上的点)有这么两种情况

    • 这两个点的控制点是一样(假设为 c o n [ x ] con[x] con[x]) 那么这两个点之间的所有点的控制点显然都是 c o n [ x ] con[x] con[x]

    • 这两个点的控制点是不一样的假设为 c o n [ x ] , c o n [ y ] con[x],con[y] con[x],con[y]

      首先倍增出 m i d mid mid点 再根据 c o n [ x ] , c o n [ y ] con[x],con[y] con[x],con[y]的大小判断距离相同时的情况

  • Code

    #include 
    using namespace std;
    const int N = 3e5 + 10, LOG = 20;
    struct data {
    	int nt, to;
    } a[N << 1];
    int con[N], f[N], be[N], af[N], st[N], rem[N], c[N];
    int head[N], g[N][LOG + 1], dep[N], sz[N], dfn[N];
    int n, cnt = 0, res = 0, now = 0, top = 0;
    
    void add(int x, int y) {
    	a[++cnt].to = y;
    	a[cnt].nt = head[x];
    	head[x] = cnt;
    }
    
    void predfs(int x, int fa) {
    	dfn[x] = ++res;
    	dep[x] = dep[fa] + 1;
    	sz[x] = 1, g[x][0] = fa;
    	for(int i = head[x]; i; i = a[i].nt) {
    		int to = a[i].to;
    		if(to == fa) {
    			continue;
    		}
    		predfs(to, x);
    		sz[x] += sz[to];
    	}
    }
    
    void prepare() {
    	for(int j = 1; j <= LOG; ++j)
    		for(int i = 1; i <= n; ++i)
    			g[i][j] = g[g[i][j - 1]][j - 1];
    }
    
    int GetLCA(int A, int B) {
    	if(dep[A] > dep[B]) 
    		swap(A, B);
    	for(int i = LOG; i >= 0; --i)
    		if(dep[g[B][i]] >= dep[A])
    			B = g[B][i];
    	if(A == B)
    		return A;
    	for(int i = LOG; i >= 0; --i)
    		if(g[A][i] != g[B][i])
    			A = g[A][i], B = g[B][i];
    	return g[A][0];
    }
    
    int dis(int x, int y) {
    	return dep[x] + dep[y] - 2 * dep[GetLCA(x, y)];
    }
    
    void dfs(int x) {
    	c[++now] = x, rem[x] = sz[x];
    	for(int i = head[x]; i; i = a[i].nt) {
    		int to = a[i].to;
    		dfs(to);
    		if(!con[to]) continue;
    		int d1 = dis(x, con[to]), d2 = dis(x, con[x]);
    		if((d1 == d2 && (con[to] < con[x])) || d1 < d2 || !con[x])
    			con[x] = con[to];
    	}
    }
    
    void Dfs(int x) {
    	if(!con[x]) return ;
    	for(int i = head[x]; i; i = a[i].nt) {
    		int to = a[i].to;
    		int d1 = dis(to, con[to]), d2 = dis(to, con[x]);
    		if((d1 == d2 && (con[x] < con[to])) || d1 > d2 || !con[to])
    			con[to] = con[x];
    		Dfs(to);
    	}
    }
    
    void work(int x, int y) {
    	int ny = y, mid = y;
    	for(int i = LOG; i >= 0; --i)
    		if(dep[g[ny][i]] > dep[x])
    			ny = g[ny][i];
    	rem[x] -= sz[ny];
    	if(con[x] == con[y]) {
    		f[con[x]] += sz[ny] - sz[y];
    		return ;
    	}
    	for(int i = LOG; i >= 0; --i) {
    		int nxt = g[mid][i];
    		if(dep[nxt] <= dep[x]) continue;
    		int d1 = dis(con[x], nxt), d2 = dis(con[y], nxt);
    		if(d1 > d2 || (d1 == d2 && con[y] < con[x])) mid = nxt;
    	}
    	f[con[x]] += sz[ny] - sz[mid];
    	f[con[y]] += sz[mid] - sz[y];
    }
    
    bool cmp(int x, int y) {
    	return dfn[x] < dfn[y];
    }
    
    void go() {
    	now = cnt = top = 0;
    	int num;
    	scanf("%d", &num);
    	for(int i = 1; i <= num; ++i) {
    		scanf("%d", &be[i]);
    		con[be[i]] = be[i], af[i] = be[i];
    	}
    	sort(af + 1, af + 1 + num, cmp);
    	if(con[1] != 1) st[++top] = 1;
    	for(int i = 1; i <= num; ++i) {
    		int pos = af[i], fa = 0;
    		for(; top;) {
    			fa = GetLCA(st[top], pos);
    			if(top > 1 && dep[fa] < dep[st[top - 1]]) {
    				add(st[top - 1], st[top]), top--;
    			} else if(dep[fa] < dep[st[top]]) {
    				add(fa, st[top]), top--;
    				break;
    			} else {
    				break;
    			}
    		}
    		if(st[top] != fa) 
    			st[++top] = fa;
    		st[++top] = pos;
    	}
    	for(; top > 1; --top)
    		add(st[top - 1], st[top]);
    	dfs(1), Dfs(1);
    	for(int i = 1; i <= now; ++i)
    		for(int j = head[c[i]]; j; j = a[j].nt)
    			work(c[i], a[j].to);
    	for(int i = 1; i <= now; ++i)
    		f[con[c[i]]] += rem[c[i]];
    	for(int i = 1; i <= num; ++i) 
    		printf("%d ", f[be[i]]);
    	puts("");
    	for(int i = 1; i <= now; ++i)
    		head[c[i]] = con[c[i]] = f[c[i]] = rem[c[i]] = 0;
    }
    
    int main() {
    	scanf("%d", &n);
    	for(int i = 1, x, y; i < n; ++i) {
    		scanf("%d%d", &x, &y);
    		add(x, y), add(y, x);
    	}
    	predfs(1, 0), prepare();
    	memset(head, 0, sizeof head);
    	int m;
    	scanf("%d", &m);
    	for(int o = 1; o <= m; ++o) {
    		go();
    	}
    	return 0;
    }
    

你可能感兴趣的:(虚树)