【Codeforces】Codeforces Round #299 (Div. 1) E. Tavas on the Path 【树链剖分+区间合并】

传送门:【Codeforces】Codeforces Round #299 (Div. 1) E. Tavas on the Path


大概题意:

首先对于一个串s,我们可以提取m个只包含1的块,每个块都是s里的一个最长连续1子串。

然后我们设x1,x2,x3...xm分别为这m个块中1的长度。

现在我们有函数T(S),定义 ,f(xi)是f关于xi的函数,下面会给出。

现在我们有一棵树,N个节点N-1条边,每条边一个边权。然后我们有N-1个f(x)的值,分别对应不同的长度。

然后我们有Q次询问,每次询问包含两个点u,v和一个权值w,然后我们用这条路径上所有的边来构造一个01串b,其中边权大于等于w的赋值为1,否则赋值为0(比如询问是u=1,v=6,w=2,5条边分别是(1,2,1),(2,3,2),(3,4,1),(4,5,2),(5,6,2),那么构成的01串b就是01011)。然后我们就要输出T(b)的值,也就是f(1)+f(2)。

题目要求就是,给一棵N节点的树,N-1个f函数,N-1条边,Q个询问,每次询问输出T(b)的值。


题目分析:

这个如果看懂题意,并且会树链剖分以及区间合并操作的话,就没有问题了。。

首先!我们先对边以及操作离个线,然后按照权值排序(升序降序随意,我用的升序),把所有的边都设为1然后插入到线段树中(即一开始的路径都是全1串),然后按顺序枚举询问,将权值小于询问的边对应的值置为0。这样我们就可以离线处理问题啦~

接下来大致讲一下树链剖分中的区间合并怎么操作。

我们假设路径是有向的,那么我们维护左链,维护左链的右端点Lnum(Lnum=1当且仅当左链的右端点是1),以及从左链的右端往左的连续1个数Llen,对于右链我们同理,维护Rnum,Rlen。我们再维护一个答案ans,表示这次查询的结果。这些变量初始值都设为0。

线段树内的查询,我们可以维护7个值(为了方便。。),区间左端点lnum,右端点rnum,左端点向右的连续1个数llen,右端点向左的连续1个数rlen,区间内部分块的和sum,区间范围[l,r]。对于区间合并,我们设左区间为tmpl,右区间为tmpr,要合并成的区间为tmp,那么首先tmp.sum = tmpl.sum + tmpr.sum,毋庸置疑。然后什么时候需要改变sum?自然是tmpl.rnum和tmpr.lnum都等于1的时候了,这时候tmp.sum += f[tmpl.rlen + tmpr.llen] - f[tmpl.rlen] - f[tmpr.llen]。其他值的转移的应该就不用我细说了。处理完以后这个区间就返回一个tmp即可。这样线段树内的维护就没问题了。

然后对于路径上的所有子区间,我们发现这些子区间是从下到上依次访问的,所以我们挨个合并,并且分两路合并,对于左链上的子区间tmp,维护方式和线段树内的维护尤为相似(本就是同一种操作嘛~),就是看Lnum是不是和tmp.rnum一样都为1(我们假设子区间的llen靠近根,而rlen靠近叶子),是的话我们就维护ans,ans += f[Llen + tmp.rlen] - f[Llen] - f[tmp.rlen]。其他操作不细说了。

对于右链和左链类似的处理就行了。注意各个变量不要搞错!

最后到了左链和右链汇合的时候了,如果中间还有个区间,三个依次合并,否则直接合并左右链。然后对于一个操作我们就做完了~

因为是离线的,所以我们用一个数组保存解,最后依次输出即可。


my code:

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std ;

typedef long long LL ;
typedef unsigned long long ULL ;

#define rep( i , a , b ) for ( int i = ( a ) ; i <  ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define clr( a , x ) memset ( a , x , sizeof a )
#define ls ( o << 1 )
#define rs ( o << 1 | 1 )
#define lson ls , l , m
#define rson rs , m + 1 , r
#define root 1 , 1 , n
#define rt o , l , r
#define mid ( ( l + r ) >> 1 )

const int MAXN = 100005 ;
const int MAXE = 200005 ;

struct Edge {
	int v , n ;
	Edge () {}
	Edge ( int v , int n ) : v ( v ) , n ( n ) {}
} ;

struct Node {
	int x , y , z , idx ;
	void input ( int i ) {
		scanf ( "%d%d%d" , &x , &y , &z ) ;
		idx = i ;
	}
	bool operator < ( const Node& a ) const {
		return z < a.z ;
	}
} ;

struct Query_Node {
	int lnum , rnum , llen , rlen , sum , l , r ;
	Query_Node () {}
	Query_Node ( int lnum , int rnum , int llen , int rlen , int sum , int l , int r ) :
		lnum ( lnum ) , rnum ( rnum ) , llen ( llen ) , rlen ( rlen ) , sum ( sum ) , l ( l ) , r ( r ) {}
} ;

Node e[MAXN] , q[MAXN] ;
Edge E[MAXE] ;
int H[MAXN] , cntE ;
int pre[MAXN] ;
int siz[MAXN] ;
int son[MAXN] ;
int dep[MAXN] ;
int pos[MAXN] ;
int idx[MAXN] ;
int top[MAXN] ;
int tree_idx ;
int n , m ;
int f[MAXN] ;
int sum[MAXN << 2] ;
int lnum[MAXN << 2] ;
int rnum[MAXN << 2] ;
int llen[MAXN << 2] ;
int rlen[MAXN << 2] ;
int res[MAXN] ;

void clear () {
	tree_idx = 0 ;
	dep[1] = 0 ;
	pre[1] = 0 ;
	siz[0] = 0 ;
	cntE = 0 ;
	clr ( H , -1 ) ;
}

void addedge ( int u , int v ) {
	E[cntE] = Edge ( v , H[u] ) ;
	H[u] = cntE ++ ;
}

void dfs ( int u ) {
	siz[u] = 1 ;
	son[u] = 0 ;
	for ( int i = H[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v ;
		if ( v == pre[u] ) continue ;
		pre[v] = u ;
		dep[v] = dep[u] + 1 ;
		dfs ( v ) ;
		siz[u] += siz[v] ;
		if ( siz[son[u]] < siz[v] ) son[u] = v ;
	}
}

void rebuild ( int u , int top_element ) {
	top[u] = top_element ;
	pos[u] = ++ tree_idx ;
	if ( son[u] ) rebuild ( son[u] , top_element ) ;
	for ( int i = H[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v ;
		if ( v != pre[u] && v != son[u] ) rebuild ( v , v ) ;
	}
}

void pushup ( int o , int l , int r ) {
	sum[o] = sum[ls] + sum[rs] ;
	llen[o] = llen[ls] ;
	rlen[o] = rlen[rs] ;
	lnum[o] = lnum[ls] ;
	rnum[o] = rnum[rs] ;
	if ( rnum[ls] && lnum[rs] ) {
		sum[o] -= f[rlen[ls]] + f[llen[rs]] ;
		sum[o] += f[rlen[ls] + llen[rs]] ;
	}
	int m = mid ;
	if ( llen[ls] == m - l + 1 ) llen[o] += llen[rs] ;
	if ( rlen[rs] == r - m ) rlen[o] += rlen[ls] ;
}

void update ( int x , int v , int o , int l , int r ) {
	if ( l == r ) {
		lnum[o] = rnum[o] = v ;
		llen[o] = rlen[o] = v ;
		if ( v ) sum[o] = f[1] ;
		else sum[o] = 0 ;
		return ;
	}
	int m = mid ;
	if ( x <= m ) update ( x , v , lson ) ;
	else update ( x , v , rson ) ;
	pushup ( rt ) ;
}

Query_Node query ( int L , int R , int o , int l , int r ) {
	if ( L <= l && r <= R ) return Query_Node ( lnum[o] , rnum[o] , llen[o] , rlen[o] , sum[o] , l , r ) ;
	int m = mid ;
	if ( R <= m ) return query ( L , R , lson ) ;
	if ( m <  L ) return query ( L , R , rson ) ;
	Query_Node tmpl = query ( L , R , lson ) ;
	Query_Node tmpr = query ( L , R , rson ) ;
	Query_Node tmp = Query_Node ( tmpl.lnum , tmpr.rnum , tmpl.llen , tmpr.rlen , tmpl.sum + tmpr.sum , tmpl.l , tmpr.r ) ;
	if ( tmpl.rnum && tmpr.lnum ) {
		tmp.sum -= f[tmpl.rlen] + f[tmpr.llen] ;
		tmp.sum += f[tmpl.rlen + tmpr.llen] ;
	}
	if ( tmpl.llen == tmpl.r - tmpl.l + 1 ) tmp.llen += tmpr.llen ;
	if ( tmpr.rlen == tmpr.r - tmpr.l + 1 ) tmp.rlen += tmpl.rlen ;
	return tmp ;
}

int Query ( int x , int y ) {
	int ans = 0 ;
	int Llen = 0 , Rlen = 0 , Lnum = 0 , Rnum = 0 ;
	while ( top[x] != top[y] ) {
		if ( dep[top[x]] > dep[top[y]] ) {
			Query_Node tmp = query ( pos[top[x]] , pos[x] , root ) ;
			ans += tmp.sum ;
			if ( tmp.rnum && Lnum ) {
				ans -= f[tmp.rlen] + f[Llen] ;
				ans += f[tmp.rlen + Llen] ;
			}
			if ( tmp.llen == tmp.r - tmp.l + 1 ) Llen += tmp.llen ;
			else Llen = tmp.llen ;
			Lnum = tmp.lnum ;
			x = pre[top[x]] ;
		} else {
			Query_Node tmp = query ( pos[top[y]] , pos[y] , root ) ;
			ans += tmp.sum ;
			if ( tmp.rnum && Rnum ) {
				ans -= f[tmp.rlen] + f[Rlen] ;
				ans += f[tmp.rlen + Rlen] ;
			}
			if ( tmp.llen == tmp.r - tmp.l + 1 ) Rlen += tmp.llen ;
			else Rlen = tmp.llen ;
			Rnum = tmp.lnum ;
			y = pre[top[y]] ;
		}
	}
	if ( dep[x] < dep[y] ) {
		Query_Node tmp = query ( pos[x] + 1 , pos[y] , root ) ;
		ans += tmp.sum ;
		if ( Lnum && tmp.lnum ) {
			ans -= f[Llen] + f[tmp.llen] ;
			ans += f[Llen + tmp.llen] ;
		}
		if ( tmp.rlen == tmp.r - tmp.l + 1 ) {
			Llen += tmp.rlen ;
			if ( Rnum ) {
				ans -= f[Llen] + f[Rlen] ;
				ans += f[Llen + Rlen] ;
			}
		} else {
			if ( tmp.rnum && Rnum ) {
				ans -= f[tmp.rlen] + f[Rlen] ;
				ans += f[tmp.rlen + Rlen] ;
			}
		}
	} else if ( dep[y] < dep[x] ) {
		Query_Node tmp = query ( pos[y] + 1 , pos[x] , root ) ;
		ans += tmp.sum ;
		if ( Lnum && tmp.rnum ) {
			ans -= f[Llen] + f[tmp.rlen] ;
			ans += f[Llen + tmp.rlen] ;
		}
		if ( tmp.llen == tmp.r - tmp.l + 1 ) {
			Llen += tmp.rlen ;
			if ( Rnum ) {
				ans -= f[Llen] + f[Rlen] ;
				ans += f[Llen + Rlen] ;
			}
		} else {
			if ( tmp.lnum && Rnum ) {
				ans -= f[tmp.llen] + f[Rlen] ;
				ans += f[tmp.llen + Rlen] ;
			}
		}
	} else {
		if ( Lnum && Rnum ) {
			ans -= f[Llen] + f[Rlen] ;
			ans += f[Llen + Rlen] ;
		}
	}
	return ans ;
}

void solve () {
	clear () ;
	rep ( i , 1 , n ) scanf ( "%d" , &f[i] ) ;
	rep ( i , 1 , n ) {
		e[i].input ( i ) ;
		addedge ( e[i].x , e[i].y ) ;
		addedge ( e[i].y , e[i].x ) ;
	}
	rep ( i , 0 , m ) q[i].input ( i ) ;
	dfs ( 1 ) ;
	rebuild ( 1 , 1 ) ;
	sort ( e + 1 , e + n ) ;
	sort ( q + 0 , q + m ) ;
	clr ( sum , 0 ) ;
	clr ( lnum , 0 ) ;
	clr ( rnum , 0 ) ;
	clr ( llen , 0 ) ;
	clr ( rlen , 0 ) ;
	rep ( i , 1 , n ) {
		int u = e[i].x , v = e[i].y ;
		if ( dep[u] > dep[v] ) idx[i] = u ;
		else idx[i] = v ;
		update ( pos[idx[i]] , 1 , root ) ;
	}
	int j = 1 ;
	rep ( i , 0 , m ) {
		while ( j < n && e[j].z < q[i].z ) {
			update ( pos[idx[j]] , 0 , root ) ;
			++ j ;
		}
		res[q[i].idx] = Query ( q[i].x , q[i].y ) ;
	}
	rep ( i , 0 , m ) printf ( "%d\n" , res[i] ) ;
}

int main () {
	while ( ~scanf ( "%d%d" , &n , &m ) ) solve () ;
	return 0 ;
}


你可能感兴趣的:(codeforces)