NOIP模拟赛 T2树

题目描述

有一棵树,节点编号为 1 1 1 n n n。还有 m m m次查询,每次给出 l , r , x l,r,x l,r,x,你需要回答从 x x x号点走到编号在 l l l r r r之间的点的最小距离。

时限2s,空间256MB。

输入格式

第一行一个正整数 n n n

接下来 n − 1 n-1 n1行,每行三个正整数 u , v , l u,v,l u,v,l,表示 u u u v v v之间有一条长度为 l l l的边。

接下来一行一个正整数 m m m

接下来 m m m行,每行三个正整数 l , r , x l,r,x l,r,x,表示一次询问。

输出格式

输出 m m m行,每行一个整数,表示答案。

输入样例

3
1 2 1
1 3 1
3
2 3 1
2 3 2
3 3 2

输出样例

1
0
2

数据范围

n , m ≤ 1 0 5 n,m\leq 10^5 n,m105,任意两点之间距离不超过 1 0 9 10^9 109


题解

在读完所有边之后,我们可以做一次dfs。但这个dfs不是一般的dfs,它遍历的方式不同。

dfs需要记录当前的点 u u u和深度dep。假设当前从边 e e e连到点 u u u,然后对 u u u进行遍历。这时,我们把所有从 u u u开始不经过边 e e e可以遍历到的点看作一棵树。在这个连通块中,我们需要找到一个 x x x,满足如果在这棵树中如果以 x x x为根,那么 x x x的各个子树的大小都不超过这棵树的大小的一半。然后,以 x x x为根,将这棵树遍历一次,每个点 i i i的第 d e p dep dep层祖先都是 x x x,记为 t o [ d e p ] [ i ] to[dep][i] to[dep][i]。求出 x x x到这棵树中各个点 i i i的距离 d i s [ d e p ] [ i ] dis[dep][i] dis[dep][i]。用线段树来存第 d e p dep dep层的这个连通块中每个点到 x x x的距离,这棵线段树的根为 t r [ d e p ] [ x ] tr[dep][x] tr[dep][x]

最后,将 x x x标记为已经过的点。以 x x x为根,遍历各个子树,继续dfs。这样相当于对当前点 u u u的每个子树,将其满足条件的点 x x x当作这棵子树的根节点,然后继续遍历。点 u u u连向子树的边的另一端不一定是子树的根节点,但一定是这颗子树的一个点。

最后,对于每一组询问 l , r , x l,r,x l,r,x,枚举层数 i i i,这一层的答案为 x x x到第 i i i层祖先的距离 d i s [ i ] [ x ] dis[i][x] dis[i][x]加上第 i i i层祖先的线段树中到编号在 l l l r r r之间的点的最短距离。求各层的最小值即为答案。

因为对树进行了重构,所以每向下一个点,子树的大小就至少减少一半,也就是最多只有 log ⁡ n \log n logn层。对于每个节点,都将其子树遍历了一次。也就是每个节点都被其各个父亲遍历了一次,总共被遍历 n log ⁡ n n\log n nlogn次。每次遍历都被放在线段树中放一次需要 O ( log ⁡ n ) O(\log n) O(logn),所以dfs的总时间复杂度为 O ( n log ⁡ 2 n ) O(n\log^2 n) O(nlog2n)

对于查询,每次查询都枚举了层数,枚举了 log ⁡ n \log n logn次,每次都要线段树查询,时间复杂度为 O ( log ⁡ n ) O(\log n) O(logn)。每次查询的时间复杂度为 O ( log ⁡ 2 n ) O(\log^2 n) O(log2n),查询的总时间复杂度为 O ( m log ⁡ 2 n ) O(m\log^2 n) O(mlog2n)

所以,总时间复杂度为 O ( ( n + m ) log ⁡ 2 n ) O((n+m)\log^2 n) O((n+m)log2n)

对于空间,每个点会被放入线段树中 log ⁡ n \log n logn,总共最多会开 n log ⁡ 2 n n\log^2 n nlog2n个点。用动态开点即可。

这道题有一定的思维难度,可以结合代码帮助理解。

code

#include
#define lc tr[k].l
#define rc tr[k].r
using namespace std;
const int V=17,N=100000;
int n,m,tot=0,cnt=0,cl=0,d[N*2+5],l[N*2+5],r[N*2+5],w[N*2+5];
int s1,s[N+5],z[N+5],c[N+5],fa[N+5],siz[N+5];
int ans,dis[V][N+5],to[V][N+5],rt[V][N+5];
struct node{
	int l,r,s;
}tr[17000005];
queue<int>q;
void add(int xx,int yy,int zz){
	l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;w[tot]=zz;
}
void ch(int &k,int l,int r,int x,int v){
	if(!k) k=++cnt;
	if(l==r&&l==x){
		tr[k].s=v;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) ch(lc,l,mid,x,v);
	else ch(rc,mid+1,r,x,v);
	tr[k].s=min(tr[lc].s,tr[rc].s);
}
int find(int k,int l,int r,int x,int y){
	if(!k) return 1e9+5;
	if(l>=x&&r<=y) return tr[k].s;
	int mid=l+r>>1;
	int re=1e9+5;
	if(x<=mid) re=min(re,find(lc,l,mid,x,y));
	if(y>mid) re=min(re,find(rc,mid+1,r,x,y));
	return re;
}
int gtfa(int i){
	if(c[i]!=cl){
		c[i]=cl;fa[i]=0;
	}
	return fa[i];
}
void dfs(int u,int dep){
	q.push(u);
	s1=1;s[1]=u;
	++cl;
	while(!q.empty()){
		int i=q.front();q.pop();
		for(int j=r[i];j;j=l[j]){
			if(gtfa(i)!=d[j]&&!z[d[j]]){
				c[d[j]]=cl;
				fa[d[j]]=i;
				q.push(d[j]);
				s[++s1]=d[j];
			}
		}
	}
	for(int i=s1;i>=1;i--){
		siz[s[i]]=1;
		for(int j=r[s[i]];j;j=l[j]){
			if(gtfa(s[i])!=d[j]&&!z[d[j]]){
				siz[s[i]]+=siz[d[j]];
			}
		}
	}
	if(siz[u]==1){
		to[dep][u]=u;
		ch(rt[dep][u],1,n,u,0);
		return;
	}
	int x,mx;
	for(int i=1;i<=s1;i++){
		mx=siz[u]-siz[s[i]];
		for(int j=r[s[i]];j;j=l[j]){
			if(gtfa(s[i])!=d[j]&&!z[d[j]]){
				mx=max(mx,siz[d[j]]);
			}
		}
		if(mx<=siz[u]/2){
			x=s[i];break;
		}
	}
	q.push(x);
	++cl;
	to[dep][x]=x;
	dis[dep][x]=0;
	ch(rt[dep][x],1,n,x,0);
	while(!q.empty()){
		int i=q.front();q.pop();
		for(int j=r[i];j;j=l[j]){
			if(gtfa(i)!=d[j]&&!z[d[j]]){
				c[d[j]]=cl;
				fa[d[j]]=i;
				to[dep][d[j]]=x;
				dis[dep][d[j]]=dis[dep][i]+w[j];
				ch(rt[dep][x],1,n,d[j],dis[dep][d[j]]);
				q.push(d[j]);
			}
		}
	}
	z[x]=1;
	for(int i=r[x];i;i=l[i]){
		if(!z[d[i]]) dfs(d[i],dep+1);
	}
}
int main()
{
	scanf("%d",&n);
	for(int i=1,x,y,z;i<n;i++){
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z);add(y,x,z);
	}
	tr[0].s=1e9+5;
	dfs(1,0);
	scanf("%d",&m);
	for(int o=1,l,r,x;o<=m;o++){
		scanf("%d%d%d",&l,&r,&x);
		ans=1e9+5;
		for(int i=0;i<17;i++){
			if(to[i][x]){
				ans=min(ans,dis[i][x]+find(rt[i][to[i][x]],1,n,l,r));
			}
		}
		printf("%d\n",ans);
	}
	return 0;
}

你可能感兴趣的:(题解,c++,题解)