【ZJU】3863 Paths on the Tree 【浙大2015年4月校赛D题】 树分治

传送门:【ZJU】3863 Paths on the Tree


题意:给一棵树,问树上有多少个路径对有不超过K个公共节点的,路径a->b和b->a等价,路径对(A,B)和(B,A)只有当A和B是同一条路径时相同。


分析:反过来考虑,考虑有超过K+1个公共节点的路径对数。我们考虑重叠的路径部分,这个可以用树分治来搞,然后路径对的两端延伸出去的部分不重叠,我们要预处理出这个部分。最后就是当前枚举的子树和之前子树乘一乘。


trick:路径对总数是会爆long long的,要用unsigned long long。


后记:太懒了。。写不动题解啊,粗略的写了下思路,细节可以自己想想,还不理解可以QQ找我。。


代码如下:

#include <stdio.h>
#include <string.h>
#include <string>
#include <math.h>
#include <map>
#include <algorithm>
using namespace std ;

typedef long long LL ;
typedef unsigned long long ULL ;

#define rep( i , a , b ) for ( int i = ( a ) ; i <  ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define clr( a , x ) memset ( a , x , sizeof a )

const int MAXN = 100005 ;
const int MAXE = 400005 ;
struct Edge {
	int v , c , n ;
	Edge () {}
	Edge ( int v , int n ) : v ( v ) , n ( n ) {}
} ;

struct Node {
	int d ;
	ULL x ;
	Node () {}
	Node ( int d , ULL x ) : d ( d ) , x ( x ) {}
} ;

Edge E[MAXE] ;
int H[MAXN] , cntE ;
int vis[MAXN] , Time ;
int siz[MAXN] ;
int pre[MAXN] ;
int dep[MAXN] ;
int max_dis , tree_size ;
int Q[MAXN] , head , tail ;
Node S[MAXN] ;
int top ;
ULL c[MAXN] , cc[MAXN] ;
int n , k ;
ULL ans , tmpc0 ;

void clear () {
	ans = 0 ;
	cntE = 0 ;
	++ Time ;
	clr ( H , -1 ) ;
}

void addedge ( int u , int v ) {
	E[cntE] = Edge ( v , H[u] ) ;
	H[u] = cntE ++ ;
}

int get_root ( int s ) {
	head = tail = 0 ;
	Q[tail ++] = s ;
	dep[s] = 0 ;
	pre[s] = 0 ;
	while ( head != tail ) {
		int u = Q[head ++] ;
		for ( int i = H[u] ; ~i ; i = E[i].n ) {
			int v = E[i].v ;
			if ( v == pre[u] || vis[v] == Time ) continue ;
			pre[v] = u ;
			dep[v] = dep[u] + 1 ;
			Q[tail ++] = v ;
		}
	}
	max_dis = dep[Q[tail - 1]] ;
	int root = s , root_siz = MAXN , max_siz = tail ;
	tree_size = tail ;
	while ( head ) {
		int u = Q[-- head] ;
		int cnt = 0 ;
		siz[u] = 1 ;
		for ( int i = H[u] ; ~i ; i = E[i].n ) {
			int v = E[i].v ;
			if ( v == pre[u] || vis[v] == Time ) continue ;
			siz[u] += siz[v] ;
			if ( siz[v] > cnt ) cnt = siz[v] ;
		}
		cnt = max ( cnt , max_siz - siz[u] ) ;
		if ( cnt < root_siz ) {
			root_siz = cnt ;
			root = u ;
		}
	}
	return root ;
}

void calc ( int s , int s_size , int root ) {
	head = tail = 0 ;
	Q[tail ++] = s ;
	dep[s] = 1 ;
	pre[s] = root ;
	top = 0 ;
	ULL c0 = tmpc0 - ( ULL ) 2 * s_size * ( n - s_size ) ;
	while ( head != tail ) {
		int u = Q[head ++] ;
		ULL x = 1 , y = 0 ;
		for ( int i = H[u] ; ~i ; i = E[i].n ) {
			int v = E[i].v , c = E[i].c ;
			if ( v == pre[u] ) continue ;
			x += E[i].c ;
			y += ( ULL ) E[i].c * E[i].c ;
		}
		//printf ( "%I64u %I64u %d %d\n" , x , y , dep[u] , u ) ;
		S[top ++] = Node ( dep[u] , x * x - y ) ;
		for ( int i = H[u] ; ~i ; i = E[i].n ) {
			int v = E[i].v ;
			if ( v == pre[u] || vis[v] == Time ) continue ;
			pre[v] = u ;
			dep[v] = dep[u] + 1 ;
			Q[tail ++] = v ;
		}
	}
	rep ( i , 0 , top ) {
		int idx = max ( 0 , k - S[i].d - 1 ) ;
		if ( idx > max_dis ) continue ;
		if ( !idx ) ans += c0 * S[i].x ;
		ans += cc[idx] * S[i].x ;
	}
	rep ( i , 0 , top ) c[S[i].d] += S[i].x ;
	cc[max_dis] = c[max_dis] ;
	rev ( i , max_dis , 1 ) cc[i - 1] = cc[i] + c[i - 1] ;
}

void dfs ( int u ) {
	int root = get_root ( u ) ;
	if ( tree_size < k ) return ;//no satisfied path
	vis[root] = Time ;
	memset ( c , 0 , sizeof ( c[0] ) * ( max_dis + 1 ) ) ;
	memset ( cc , 0 , sizeof ( cc[0] ) * ( max_dis + 1 ) ) ;
	//For ( i , 0 , max_dis ) c[i] = cc[i] = 0 ;
	ULL y = 0 ;
	for ( int i = H[root] ; ~i ; i = E[i].n ) y += ( ULL ) E[i].c * E[i].c ;
	tmpc0 = ( ULL ) n * n - y ;
	for ( int i = H[root] ; ~i ; i = E[i].n ) {
		int v = E[i].v ;
		if( vis[v] == Time ) continue ;
		calc ( v , E[i].c , root ) ;
	}
	for ( int i = H[root] ; ~i ; i = E[i].n ) if ( vis[E[i].v] != Time ) dfs ( E[i].v ) ;
}

void pre_dfs ( int u , int f ) {
	siz[u] = 1 ;
	for ( int i = H[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v ;
		if ( v == f ) continue ;
		pre_dfs ( v , u ) ;
		E[i].c = siz[v] ;
		siz[u] += siz[v] ;
	}
	for ( int i = H[u] ; ~i ; i = E[i].n ) if ( E[i].v == f ) {
		E[i].c = n - siz[u] ;
		break ;
	}
}

void solve () {
	int u , v ;
	clear () ;
	scanf ( "%d%d" , &n , &k ) ;
	++ k ;
	rep ( i , 1 , n ) {
		scanf ( "%d%d" , &u , &v ) ;
		addedge ( u , v ) ;
		addedge ( v , u ) ;
	}
	ULL x = ( ULL ) n * ( n + 1 ) / 2 ;
	pre_dfs ( 1 , 0 ) ;
	dfs ( 1 ) ;
	printf ( "%llu\n" , x * x - ans ) ;
}

int main () {
	int T ;
	Time = 0 ;
	clr ( vis , 0 ) ;
	scanf ( "%d" , &T ) ;
	For ( i , 1 , T ) solve () ;
	return 0 ;
}


你可能感兴趣的:(ZJU)