洛谷 P3676 【小清新数据结构题】

怎么说呢,非常神的一道题

我们先忽略修改操作,考虑换根操作,假设我们的根从\(u\)换成了\(v\)那么可以注意到发生变化的\(sz\)只有两个,\(u\)\(v\)

于是我们有这次操作后的点权平方总和为\(\sum_{i=1}^nsz_i^2\)

变化的权值则为\(sum^2->(sum-sz_v)^2,sz_v^2->sum^2\)

你会发现这个还是非常不好维护

但是貌似\(\sum_{i=1}^nsz_i*(sum-sz_i)\)是个定值

这是因为一次换根只会影响到\(u,v\)两个点,所以变化的只有那两个点,但他们的变化恰好是一个从\(sz\)变成了\(sum-sz\),另一个则变成了\(sum\)

所以这指示这对于任意的一个根\(\sum_{i=1}^nsz_i*(sum-sz_i)\)是一个定值

仔细思考会发现这个定值:

\[\sum_{i=1}^nsum*sz_i-sz_i^2=S\]

\[\sum sz^2=\sum_{i=1}^nsum*sz_i-S\]

由于\(S\)为定值,假设我们可以求出来,那么我们就只需要维护\(\sum sum*sz_i\)

显然的是\(sum\)可以提到前面去,于是我们就只需要维护\(sz\)之和

进一步思考,在换根了之后,一个点对\(sz\)之和的贡献貌似是其到根的距离\(+1\)

于是我们就有:

\[\sum sz = \sum w_i*(dist(i,u)+1)\]

其中\(dist\)表示两点之间的路径,\(u\)为根

\(+1\)是可以提出来的,于是我们只需要管前面那一坨

然而有趣的是前面这一坨是要我们对于每个\(u\)维护\(\sum dist(i,u)*w_i\)

这个东西好像可以用动态点分治来维护?

我们考虑建出点分树,对每个点维护一下子树内到其的\(f(u)=\sum dist(i,u)*w_i\)\(s(u)=\sum_{}w_i\)\(g(u)=\sum dist(i,fa)*w_i\)

好像就可以转移了

\[f(u)=f(fa)-g(u)+f(u)+(s(fa)-s(u))*dist(fa,u)\]

在点分树上暴力转移的复杂度是\(O(\log^2 n)\)

每次修改就暴力修改这三个值即可,复杂度同样\(O(\log^2n)\)

接下来考虑如何在修改之后维护\(\rm S\)

仔细思考\(\rm S\)的定义发现它是对于每个点其子树内的点权和*子树外的点权和

实际上是\(\sum w_i*w_j*dist(i,j)\)

\[\sum_{i=1}^n\sum_{j=i+1}^{n}w_i*w_j*dist(i,j)\]

考虑修改,你会发现貌似对于绝大部分的点对其对答案的贡献也都没有变

貌似唯一变化的就是\(\sum_{i=1}^nw_u*w_i*dist(u,i)\)

而且因为只变了\(w_u\),所以我们提出来就是\(w_u*\sum_{i=1}^nw_i*dist(i,u)\)

这个就是答案的变化率

于是好像和换根要维护的东西是一样的...

所以也是点分树上暴力跳然后修改...

复杂度\(O(\log ^2n)\)

总体复杂度\(O(n\log^2n+q\log^2n)\)

#include
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = s; i <= t; ++ i )
#define re register
#define int long long
inline int gi() {
    char cc = getchar() ; int cn = 0, flus = 1 ; 
    while( cc > '9' || cc < '0' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
    while( cc <= '9' && cc >= '0' ) cn = cn * 10 + cc - '0', cc = getchar() ; 
    return cn * flus ; 
}
const int N = 2e5 + 5 ; 
int n, q, head[N], w[N], dp[N], vis[N], RS ; 
int d[N], f[N], g[N], fa[N], cnt, rt, root, sum, S ; 
int Fa[N], sz[N], Top[N], Son[N], dep[N], sw[N] ; 
struct E {
    int to, next ; 
} e[N * 2] ;
inline void add( int x, int y ) {
    e[++ cnt] = (E){ y, head[x] }, head[x] = cnt ,
    e[++ cnt] = (E){ x, head[y] }, head[y] = cnt ; 
}
inline void dfs1( int x, int ff ) {
    dep[x] = dep[ff] + 1, sz[x] = 1, Fa[x] = ff, sw[x] = w[x] ; 
    Next( i, x ) {
        int v = e[i].to ; if( v == ff ) continue ; 
        dfs1( v, x ), sz[x] += sz[v], sw[x] += sw[v], d[x] += d[v] ; 
        if( sz[v] > sz[Son[x]] ) Son[x] = v ;  
    }
    S += ( RS - sw[x] ) * sw[x] ; 
}
inline void dfs2( int x, int ff ) {
    Top[x] = ff ; 
    if( Son[x] ) dfs2( Son[x], ff ) ;
    Next( i, x ) {
        int v = e[i].to ; if( v == Fa[x] || v == Son[x] ) continue ;
        dfs2( v, v ) ;  
    } 
}
int LCA( int x, int y ) {
    while( Top[x] != Top[y] ) {
        if( dep[Top[x]] < dep[Top[y]] ) swap( x, y ) ;
        x = Fa[Top[x]] ; 
    }
    return ( dep[x] > dep[y] ) ? y : x ; 
}
inline void get_rt( int x, int ff ) {
    sz[x] = 1, dp[x] = 0 ; 
    Next( i, x ) {
        int v = e[i].to ; if( v == ff || vis[v] ) continue ; 
        get_rt( v, x ), sz[x] += sz[v] ;
        dp[x] = max( sz[v], dp[x] ) ;
    }
    dp[x] = max( dp[x], sum - sz[x] ) ;
    if( dp[x] <= dp[rt] ) rt = x ; 
}
inline void solve( int x ) {
    vis[x] = 1 ; 
    Next( i, x ) {
        int v = e[i].to ; if( vis[v] ) continue ; 
        rt = 0, dp[0] = sum = sz[v], get_rt( v, x ),
        fa[rt] = x, solve( rt ) ;
    }
}
inline int dist( int x, int y ) {
    return dep[x] + dep[y] - 2 * dep[LCA(x, y)] ; 
}
void Init( int x ) {
    rep( i, 1, n ) {
        int u = i, fr ; d[i] += w[i] ;
        while( fa[u] ) {
            fr = dist( i, fa[u] ), f[fa[u]] += fr * w[i], 
            d[fa[u]] += w[i], g[u] += fr * w[i], u = fa[u] ;
        }
    }
}
int Query( int x ) {
    int Ans = f[x], u = x ; 
    while( fa[u] ) {
        Ans += ( f[fa[u]] - g[u] ) ;
        Ans += ( d[fa[u]] - d[u] ) * dist( fa[u], x ) ;
        u = fa[u] ; 
    }
    return Ans ; 
}
void Update( int x, int p ) {
    int u = x, fr, y = p - w[x] ; 
    int ru = Query( x ) ;
    S += ru * y, d[x] += y, RS += y ; ;
    while( fa[u] ) {
        fr = dist( x, fa[u] ), f[fa[u]] += fr * y, 
        d[fa[u]] += y, g[u] += fr * y, u = fa[u] ;
    }
    w[x] = p ; 
}
signed main() {
    sum = n = gi(), q = gi(), rt = 0, dp[0] = n + 1 ; int opt, x, y ; 
    rep( i, 2, n ) x = gi(), y = gi(), add( x, y ) ;
    rep( i, 1, n ) w[i] = gi(), RS += w[i] ; 
    dfs1( 1, 1 ), dfs2( 1, 1 ), get_rt( 1, 1 ), 
    root = rt, solve( rt ), Init(root) ; 
    while( q-- ) {
        opt = gi(), x = gi() ;
        if( opt == 1 ) y = gi(), Update( x, y ) ; 
        else printf("%lld\n", ( Query( x ) + RS ) * RS - S ) ;
    }
    return 0 ; 
}

你可能感兴趣的:(洛谷 P3676 【小清新数据结构题】)