【codeforces】494D. Birthday 【树型DP+离线tarjan求LCA】

传送门:【codeforces】494D. Birthday


题目分析:首先明确一点,平方和是可以递推的!

1.我们第一次dfs,求出从一个点到其子树的所有点的距离的平方和。

大体是保存三个变量来递推,设u的子树上点的个数cnt,到子树上点的距离和sum,到子树上点的距离平方和sqr,每次dfs完一个儿子v以后:

dp[u].cnt+=dp[v].cnt

dp[u].sum+=dp[v].sum+dp[v].cnt*dist(u,v)

dp[u].sqr+=dp[v].sqr+2*dp[v].sum*dist(u,v)+dp[v].cnt*dist(u,v)*dist(u,v)

sqr的推导是根据sigma{ (xi+dist(u,v))^2 }(xi是v到其子树编号为i的点的距离),拆开来就发现了。


2.然后我们再做第二次dfs求出每个点到所有点的距离的平方和。

tot[u].cnt表示以u为根的子树的节点数(实际上始终等于n)

tot[u].sum表示u到所有点的距离的和

tot[u].sqr表示u到所有点的距离的平方和


3.将询问离线求出每对询问的lca,这样如果我们得到了所有的dist(1,i)便可以根据lca得到两点间的距离。


4.接下来就是分类讨论。

(1)lca(u,v)!=u&&lca(u,v)!=v

令tmp = u到v的子树中所有点的距离的平方和。我们根据lca可以得出u到v的距离dist(u,v),然后依旧是平方和展开,求得tmp,这里就不细说了。

此时答案就是2*tmp-tot[u].sqr。

(2)lca(u,v)==u&&lca(u,v)!=v

仔细观察可以发现,其实这种情况和(1)等价。

此时答案就是2*tmp-tot[u].sqr。

(3)lca(u,v)=v

此时先对v将v的子树对v的贡献从tot[v]中去掉,然后令tmp = u到v除其子树外所有点的平方和,dist(u,v)已知,所以平方和展开后同样可以得到解了。

此时答案就是tot[u].sqr-2*tmp。


至此问题完满解决。

PS:本题解法挺简单的,就是转移如果不太熟悉的话可能要搞一会了(我就是搞了好久的转移。。。主要是第二次dfs考虑了很久才想清楚)


代码如下:


#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;

typedef long long LL ;

#define rep( i , a , b ) for ( int i = ( a ) ; i <  ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define clr( a , x ) memset ( a , x , sizeof a )
#define ls ( o << 1 )
#define rs ( o << 1 | 1 )
#define lson ls , l , m
#define rson rs , m + 1 , r
#define mid ( ( l + r ) >> 1 )

const int MAXN = 100005 ;
const int MAXE = 400005 ;
const int mod = 1e9 + 7 ;

struct Edge {
	int v , c , n ;
	Edge () {}
	Edge ( int v , int c , int n ) : v ( v ) , c ( c ) , n ( n ) {}
} ;

struct Node {
	int cnt , sum , sqr ;
	Node () {}
	Node ( int cnt , int sum , int sqr ) : cnt ( cnt ) , sum ( sum ) , sqr ( sqr ) {}
} ;

struct Query {
	int u , v , lca ;
	Query () {}
	Query ( int u , int v ) : u ( u ) , v ( v ) {}
} ;

Edge E[MAXE] ;
int H[MAXN] , Q[MAXN] , cntE ;
Node dp[MAXN] , tot[MAXN] ;
Query qu[MAXN] ;
int pre[MAXN] ;
int dis[MAXN] ;
int vis[MAXN] ;
int p[MAXN] ;
int n , q ;

void clear () {
	cntE = 0 ;
	pre[1] = 0 ;
	dis[1] = 0 ;
	clr ( H , -1 ) ;
	clr ( Q , -1 ) ;
	clr ( vis , 0 ) ;
	rep ( i , 0 , MAXN ) p[i] = i ;
}

void addedge ( int u , int v , int c , int H[] ) {
	E[cntE] = Edge ( v , c , H[u] ) ;
	H[u] = cntE ++ ;
}

int find ( int x ) {
	return p[x] == x ? x : ( p[x] = find ( p[x] ) ) ;
}

void dfs1 ( int u ) {
	dp[u] = Node ( 1 , 0 , 0 ) ;
	for ( int i = H[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v , c = E[i].c ;
		if ( v == pre[u] ) continue ;
		pre[v] = u ;
		dis[v] = ( dis[u] + c ) % mod ;
		dfs1 ( v ) ;
		dp[u].cnt = ( dp[u].cnt + dp[v].cnt ) % mod ;
		dp[u].sum = ( dp[u].sum + dp[v].sum + 1LL * dp[v].cnt * c ) % mod ;
		dp[u].sqr = ( dp[u].sqr + dp[v].sqr + 2LL * dp[v].sum * c + 1LL * dp[v].cnt * c % mod * c ) % mod ;
	}
}

void dfs2 ( int u ) {
	for ( int i = H[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v , c = E[i].c ;
		if ( v == pre[u] ) continue ;
		Node tmp ;
		tmp.cnt = ( ( tot[u].cnt - dp[v].cnt ) % mod + mod ) % mod ;
		tmp.sum = ( ( tot[u].sum - dp[v].sum - 1LL * c * dp[v].cnt ) % mod + mod ) % mod ;
		tmp.sqr = ( ( tot[u].sqr - dp[v].sqr - 2LL * dp[v].sum * c - 1LL * dp[v].cnt * c % mod * c ) % mod + mod ) % mod ;
		tot[v].cnt = ( dp[v].cnt + tmp.cnt ) % mod ;
		tot[v].sum = ( dp[v].sum + tmp.sum + 1LL * c * tmp.cnt ) % mod ;
		tot[v].sqr = ( dp[v].sqr + tmp.sqr + 2LL * tmp.sum * c + 1LL * tmp.cnt * c % mod * c ) % mod ;
		dfs2 ( v ) ;
	}
}

void dfs3 ( int u ) {
	p[u] = u ;
	vis[u] = 1 ;
	for ( int i = H[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v ;
		if ( v == pre[u] ) continue ;
		dfs3 ( v ) ;
		p[v] = u ;
	}
	for ( int i = Q[u] ; ~i ; i = E[i].n ) {
		int v = E[i].v ;
		if ( vis[v] ) qu[E[i].c].lca = find ( v ) ;
	}
}

void solve () {
	int u , v , c ;
	clear () ;
	rep ( i , 1 , n ) {
		scanf ( "%d%d%d" , &u , &v , &c ) ;
		addedge ( u , v , c , H ) ;
		addedge ( v , u , c , H ) ;
	}
	scanf ( "%d" , &q ) ;
	rep ( i , 0 , q ) {
		scanf ( "%d%d" , &u , &v ) ;
		qu[i] = Query ( u , v ) ;
		addedge ( u , v , i , Q ) ;
		addedge ( v , u , i , Q ) ;
	}
	dfs1 ( 1 ) ;
	tot[1] = dp[1] ;
	dfs2 ( 1 ) ;
	dfs3 ( 1 ) ;
	//For ( i , 1 , n ) printf ( "%d = %d\n" , i , dp[i].sqr ) ;
	//For ( i , 1 , n ) printf ( "%d = %d\n" , i , tot[i].sqr ) ;
	rep ( i , 0 , q ) {
		u = qu[i].u ;
		v = qu[i].v ;
		int f = qu[i].lca ;
		int d = ( ( dis[u] + dis[v] - 2 * dis[f] ) % mod + mod ) % mod ;
		if ( f != v ) {
			int tmp = ( dp[v].sqr + 2LL * dp[v].sum * d + 1LL * dp[v].cnt * d % mod * d ) % mod ;
			printf ( "%d\n" , ( ( 2LL * tmp - tot[u].sqr ) % mod + mod ) % mod ) ;
		} else {
			Node t ;
			t.cnt = ( ( tot[v].cnt - dp[v].cnt ) % mod + mod ) % mod ;
			t.sum = ( ( tot[v].sum - dp[v].sum ) % mod + mod ) % mod ;
			t.sqr = ( ( tot[v].sqr - dp[v].sqr ) % mod + mod ) % mod ;
			int tmp = ( t.sqr + 2LL * d * t.sum + 1LL * t.cnt * d % mod * d ) % mod ;
			printf ( "%d\n" , ( ( tot[u].sqr - 2LL * tmp ) % mod + mod ) % mod ) ;
		}
	}
}

int main () {
	while ( ~scanf ( "%d" , &n ) ) solve () ;
	return 0 ;
}


你可能感兴趣的:(codeforces)