怎么说呢,非常神的一道题
我们先忽略修改操作,考虑换根操作,假设我们的根从\(u\)换成了\(v\)那么可以注意到发生变化的\(sz\)只有两个,\(u\)和\(v\)
于是我们有这次操作后的点权平方总和为\(\sum_{i=1}^nsz_i^2\)
变化的权值则为\(sum^2->(sum-sz_v)^2,sz_v^2->sum^2\)
你会发现这个还是非常不好维护
但是貌似\(\sum_{i=1}^nsz_i*(sum-sz_i)\)是个定值
这是因为一次换根只会影响到\(u,v\)两个点,所以变化的只有那两个点,但他们的变化恰好是一个从\(sz\)变成了\(sum-sz\),另一个则变成了\(sum\)
所以这指示这对于任意的一个根\(\sum_{i=1}^nsz_i*(sum-sz_i)\)是一个定值
仔细思考会发现这个定值:
\[\sum_{i=1}^nsum*sz_i-sz_i^2=S\]
\[\sum sz^2=\sum_{i=1}^nsum*sz_i-S\]
由于\(S\)为定值,假设我们可以求出来,那么我们就只需要维护\(\sum sum*sz_i\)
显然的是\(sum\)可以提到前面去,于是我们就只需要维护\(sz\)之和
进一步思考,在换根了之后,一个点对\(sz\)之和的贡献貌似是其到根的距离\(+1\)次
于是我们就有:
\[\sum sz = \sum w_i*(dist(i,u)+1)\]
其中\(dist\)表示两点之间的路径,\(u\)为根
\(+1\)是可以提出来的,于是我们只需要管前面那一坨
然而有趣的是前面这一坨是要我们对于每个\(u\)维护\(\sum dist(i,u)*w_i\)
这个东西好像可以用动态点分治来维护?
我们考虑建出点分树,对每个点维护一下子树内到其的\(f(u)=\sum dist(i,u)*w_i\),\(s(u)=\sum_{}w_i\),\(g(u)=\sum dist(i,fa)*w_i\)
好像就可以转移了
\[f(u)=f(fa)-g(u)+f(u)+(s(fa)-s(u))*dist(fa,u)\]
在点分树上暴力转移的复杂度是\(O(\log^2 n)\)的
每次修改就暴力修改这三个值即可,复杂度同样\(O(\log^2n)\)
接下来考虑如何在修改之后维护\(\rm S\)
仔细思考\(\rm S\)的定义发现它是对于每个点其子树内的点权和*子树外的点权和
实际上是\(\sum w_i*w_j*dist(i,j)\)
\[\sum_{i=1}^n\sum_{j=i+1}^{n}w_i*w_j*dist(i,j)\]
考虑修改,你会发现貌似对于绝大部分的点对其对答案的贡献也都没有变
貌似唯一变化的就是\(\sum_{i=1}^nw_u*w_i*dist(u,i)\)
而且因为只变了\(w_u\),所以我们提出来就是\(w_u*\sum_{i=1}^nw_i*dist(i,u)\)
这个就是答案的变化率
于是好像和换根要维护的东西是一样的...
所以也是点分树上暴力跳然后修改...
复杂度\(O(\log ^2n)\)
总体复杂度\(O(n\log^2n+q\log^2n)\)
#include
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = s; i <= t; ++ i )
#define re register
#define int long long
inline int gi() {
char cc = getchar() ; int cn = 0, flus = 1 ;
while( cc > '9' || cc < '0' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
while( cc <= '9' && cc >= '0' ) cn = cn * 10 + cc - '0', cc = getchar() ;
return cn * flus ;
}
const int N = 2e5 + 5 ;
int n, q, head[N], w[N], dp[N], vis[N], RS ;
int d[N], f[N], g[N], fa[N], cnt, rt, root, sum, S ;
int Fa[N], sz[N], Top[N], Son[N], dep[N], sw[N] ;
struct E {
int to, next ;
} e[N * 2] ;
inline void add( int x, int y ) {
e[++ cnt] = (E){ y, head[x] }, head[x] = cnt ,
e[++ cnt] = (E){ x, head[y] }, head[y] = cnt ;
}
inline void dfs1( int x, int ff ) {
dep[x] = dep[ff] + 1, sz[x] = 1, Fa[x] = ff, sw[x] = w[x] ;
Next( i, x ) {
int v = e[i].to ; if( v == ff ) continue ;
dfs1( v, x ), sz[x] += sz[v], sw[x] += sw[v], d[x] += d[v] ;
if( sz[v] > sz[Son[x]] ) Son[x] = v ;
}
S += ( RS - sw[x] ) * sw[x] ;
}
inline void dfs2( int x, int ff ) {
Top[x] = ff ;
if( Son[x] ) dfs2( Son[x], ff ) ;
Next( i, x ) {
int v = e[i].to ; if( v == Fa[x] || v == Son[x] ) continue ;
dfs2( v, v ) ;
}
}
int LCA( int x, int y ) {
while( Top[x] != Top[y] ) {
if( dep[Top[x]] < dep[Top[y]] ) swap( x, y ) ;
x = Fa[Top[x]] ;
}
return ( dep[x] > dep[y] ) ? y : x ;
}
inline void get_rt( int x, int ff ) {
sz[x] = 1, dp[x] = 0 ;
Next( i, x ) {
int v = e[i].to ; if( v == ff || vis[v] ) continue ;
get_rt( v, x ), sz[x] += sz[v] ;
dp[x] = max( sz[v], dp[x] ) ;
}
dp[x] = max( dp[x], sum - sz[x] ) ;
if( dp[x] <= dp[rt] ) rt = x ;
}
inline void solve( int x ) {
vis[x] = 1 ;
Next( i, x ) {
int v = e[i].to ; if( vis[v] ) continue ;
rt = 0, dp[0] = sum = sz[v], get_rt( v, x ),
fa[rt] = x, solve( rt ) ;
}
}
inline int dist( int x, int y ) {
return dep[x] + dep[y] - 2 * dep[LCA(x, y)] ;
}
void Init( int x ) {
rep( i, 1, n ) {
int u = i, fr ; d[i] += w[i] ;
while( fa[u] ) {
fr = dist( i, fa[u] ), f[fa[u]] += fr * w[i],
d[fa[u]] += w[i], g[u] += fr * w[i], u = fa[u] ;
}
}
}
int Query( int x ) {
int Ans = f[x], u = x ;
while( fa[u] ) {
Ans += ( f[fa[u]] - g[u] ) ;
Ans += ( d[fa[u]] - d[u] ) * dist( fa[u], x ) ;
u = fa[u] ;
}
return Ans ;
}
void Update( int x, int p ) {
int u = x, fr, y = p - w[x] ;
int ru = Query( x ) ;
S += ru * y, d[x] += y, RS += y ; ;
while( fa[u] ) {
fr = dist( x, fa[u] ), f[fa[u]] += fr * y,
d[fa[u]] += y, g[u] += fr * y, u = fa[u] ;
}
w[x] = p ;
}
signed main() {
sum = n = gi(), q = gi(), rt = 0, dp[0] = n + 1 ; int opt, x, y ;
rep( i, 2, n ) x = gi(), y = gi(), add( x, y ) ;
rep( i, 1, n ) w[i] = gi(), RS += w[i] ;
dfs1( 1, 1 ), dfs2( 1, 1 ), get_rt( 1, 1 ),
root = rt, solve( rt ), Init(root) ;
while( q-- ) {
opt = gi(), x = gi() ;
if( opt == 1 ) y = gi(), Update( x, y ) ;
else printf("%lld\n", ( Query( x ) + RS ) * RS - S ) ;
}
return 0 ;
}