传送门:【codeforces】494D. Birthday
题目分析:首先明确一点,平方和是可以递推的!
1.我们第一次dfs,求出从一个点到其子树的所有点的距离的平方和。
大体是保存三个变量来递推,设u的子树上点的个数cnt,到子树上点的距离和sum,到子树上点的距离平方和sqr,每次dfs完一个儿子v以后:
dp[u].cnt+=dp[v].cnt
dp[u].sum+=dp[v].sum+dp[v].cnt*dist(u,v)
dp[u].sqr+=dp[v].sqr+2*dp[v].sum*dist(u,v)+dp[v].cnt*dist(u,v)*dist(u,v)
sqr的推导是根据sigma{ (xi+dist(u,v))^2 }(xi是v到其子树编号为i的点的距离),拆开来就发现了。
2.然后我们再做第二次dfs求出每个点到所有点的距离的平方和。
tot[u].cnt表示以u为根的子树的节点数(实际上始终等于n)
tot[u].sum表示u到所有点的距离的和
tot[u].sqr表示u到所有点的距离的平方和
3.将询问离线求出每对询问的lca,这样如果我们得到了所有的dist(1,i)便可以根据lca得到两点间的距离。
4.接下来就是分类讨论。
(1)lca(u,v)!=u&&lca(u,v)!=v
令tmp = u到v的子树中所有点的距离的平方和。我们根据lca可以得出u到v的距离dist(u,v),然后依旧是平方和展开,求得tmp,这里就不细说了。
此时答案就是2*tmp-tot[u].sqr。
(2)lca(u,v)==u&&lca(u,v)!=v
仔细观察可以发现,其实这种情况和(1)等价。
此时答案就是2*tmp-tot[u].sqr。
(3)lca(u,v)=v
此时先对v将v的子树对v的贡献从tot[v]中去掉,然后令tmp = u到v除其子树外所有点的平方和,dist(u,v)已知,所以平方和展开后同样可以得到解了。
此时答案就是tot[u].sqr-2*tmp。
至此问题完满解决。
PS:本题解法挺简单的,就是转移如果不太熟悉的话可能要搞一会了(我就是搞了好久的转移。。。主要是第二次dfs考虑了很久才想清楚)
代码如下:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std ; typedef long long LL ; #define rep( i , a , b ) for ( int i = ( a ) ; i < ( b ) ; ++ i ) #define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i ) #define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i ) #define clr( a , x ) memset ( a , x , sizeof a ) #define ls ( o << 1 ) #define rs ( o << 1 | 1 ) #define lson ls , l , m #define rson rs , m + 1 , r #define mid ( ( l + r ) >> 1 ) const int MAXN = 100005 ; const int MAXE = 400005 ; const int mod = 1e9 + 7 ; struct Edge { int v , c , n ; Edge () {} Edge ( int v , int c , int n ) : v ( v ) , c ( c ) , n ( n ) {} } ; struct Node { int cnt , sum , sqr ; Node () {} Node ( int cnt , int sum , int sqr ) : cnt ( cnt ) , sum ( sum ) , sqr ( sqr ) {} } ; struct Query { int u , v , lca ; Query () {} Query ( int u , int v ) : u ( u ) , v ( v ) {} } ; Edge E[MAXE] ; int H[MAXN] , Q[MAXN] , cntE ; Node dp[MAXN] , tot[MAXN] ; Query qu[MAXN] ; int pre[MAXN] ; int dis[MAXN] ; int vis[MAXN] ; int p[MAXN] ; int n , q ; void clear () { cntE = 0 ; pre[1] = 0 ; dis[1] = 0 ; clr ( H , -1 ) ; clr ( Q , -1 ) ; clr ( vis , 0 ) ; rep ( i , 0 , MAXN ) p[i] = i ; } void addedge ( int u , int v , int c , int H[] ) { E[cntE] = Edge ( v , c , H[u] ) ; H[u] = cntE ++ ; } int find ( int x ) { return p[x] == x ? x : ( p[x] = find ( p[x] ) ) ; } void dfs1 ( int u ) { dp[u] = Node ( 1 , 0 , 0 ) ; for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v , c = E[i].c ; if ( v == pre[u] ) continue ; pre[v] = u ; dis[v] = ( dis[u] + c ) % mod ; dfs1 ( v ) ; dp[u].cnt = ( dp[u].cnt + dp[v].cnt ) % mod ; dp[u].sum = ( dp[u].sum + dp[v].sum + 1LL * dp[v].cnt * c ) % mod ; dp[u].sqr = ( dp[u].sqr + dp[v].sqr + 2LL * dp[v].sum * c + 1LL * dp[v].cnt * c % mod * c ) % mod ; } } void dfs2 ( int u ) { for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v , c = E[i].c ; if ( v == pre[u] ) continue ; Node tmp ; tmp.cnt = ( ( tot[u].cnt - dp[v].cnt ) % mod + mod ) % mod ; tmp.sum = ( ( tot[u].sum - dp[v].sum - 1LL * c * dp[v].cnt ) % mod + mod ) % mod ; tmp.sqr = ( ( tot[u].sqr - dp[v].sqr - 2LL * dp[v].sum * c - 1LL * dp[v].cnt * c % mod * c ) % mod + mod ) % mod ; tot[v].cnt = ( dp[v].cnt + tmp.cnt ) % mod ; tot[v].sum = ( dp[v].sum + tmp.sum + 1LL * c * tmp.cnt ) % mod ; tot[v].sqr = ( dp[v].sqr + tmp.sqr + 2LL * tmp.sum * c + 1LL * tmp.cnt * c % mod * c ) % mod ; dfs2 ( v ) ; } } void dfs3 ( int u ) { p[u] = u ; vis[u] = 1 ; for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( v == pre[u] ) continue ; dfs3 ( v ) ; p[v] = u ; } for ( int i = Q[u] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( vis[v] ) qu[E[i].c].lca = find ( v ) ; } } void solve () { int u , v , c ; clear () ; rep ( i , 1 , n ) { scanf ( "%d%d%d" , &u , &v , &c ) ; addedge ( u , v , c , H ) ; addedge ( v , u , c , H ) ; } scanf ( "%d" , &q ) ; rep ( i , 0 , q ) { scanf ( "%d%d" , &u , &v ) ; qu[i] = Query ( u , v ) ; addedge ( u , v , i , Q ) ; addedge ( v , u , i , Q ) ; } dfs1 ( 1 ) ; tot[1] = dp[1] ; dfs2 ( 1 ) ; dfs3 ( 1 ) ; //For ( i , 1 , n ) printf ( "%d = %d\n" , i , dp[i].sqr ) ; //For ( i , 1 , n ) printf ( "%d = %d\n" , i , tot[i].sqr ) ; rep ( i , 0 , q ) { u = qu[i].u ; v = qu[i].v ; int f = qu[i].lca ; int d = ( ( dis[u] + dis[v] - 2 * dis[f] ) % mod + mod ) % mod ; if ( f != v ) { int tmp = ( dp[v].sqr + 2LL * dp[v].sum * d + 1LL * dp[v].cnt * d % mod * d ) % mod ; printf ( "%d\n" , ( ( 2LL * tmp - tot[u].sqr ) % mod + mod ) % mod ) ; } else { Node t ; t.cnt = ( ( tot[v].cnt - dp[v].cnt ) % mod + mod ) % mod ; t.sum = ( ( tot[v].sum - dp[v].sum ) % mod + mod ) % mod ; t.sqr = ( ( tot[v].sqr - dp[v].sqr ) % mod + mod ) % mod ; int tmp = ( t.sqr + 2LL * d * t.sum + 1LL * t.cnt * d % mod * d ) % mod ; printf ( "%d\n" , ( ( tot[u].sqr - 2LL * tmp ) % mod + mod ) % mod ) ; } } } int main () { while ( ~scanf ( "%d" , &n ) ) solve () ; return 0 ; }