GP of China H Inner Product: 边分治 + 虚树dp

题意

给出两棵树 $T$、$T'$(点集均为 $1 \dots n$),求
$$\sum_{i,j \in [1,n]} dis(i,j) \cdot dis'(i,j)$$
其中 $dis$、$dis'$ 分别表示 $T$、$T'$ 上两点间的距离。

题解

对 $T$ 进行边分治,当前分治的边为 $\langle u,v \rangle$,边权为 $w$ 时,设 $u$ 一侧的点集为 $L$,$v$ 一侧的点集是 $R$(下文记 $dis[i]$ 为 $i$ 到本侧端点 $u$ 或 $v$ 的距离,$dep'[i]$ 为 $i$ 在 $T'$ 中的带权深度),则经过当前边的 $dis(i,j)$ 对答案的贡献是:
$$
\begin{aligned}
ANS &= \sum_{i \in L, j \in R} \bigl(dis(i,u) + dis(j,v) + w\bigr) \cdot dis'(i,j) \\
&= \sum_{i \in L, j \in R} (dis[i] + dis[j]) \cdot dis'(i,j) + w \cdot \sum_{i \in L, j \in R} dis'(i,j) \\
&= \sum_{i \in L, j \in R} (dis[i] + dis[j]) \cdot \bigl(dep'[i] + dep'[j] - 2 \cdot dep'[lca'(i,j)]\bigr) + w \cdot \sum_{i \in L, j \in R} dis'(i,j) \\
&= \sum_{i \in L, j \in R} (dis[i] + dis[j]) \cdot (dep'[i] + dep'[j]) - 2 \cdot \sum_{i \in L, j \in R} (dis[i] + dis[j]) \cdot dep'[lca'(i,j)] + w \cdot \sum_{i \in L, j \in R} dis'(i,j) \\
&= \sum_{i \in L, j \in R} (dis[i] \cdot dep'[i] + dis[j] \cdot dep'[j]) + \sum_{i \in L, j \in R} (dis[i] \cdot dep'[j] + dis[j] \cdot dep'[i]) \\
&\quad - 2 \cdot \sum_{i \in L, j \in R} (dis[i] + dis[j]) \cdot dep'[lca'(i,j)] + w \cdot \sum_{i \in L, j \in R} dis'(i,j) \\
&= |R| \cdot \sum_{i \in L} dis[i] \cdot dep'[i] + |L| \cdot \sum_{j \in R} dis[j] \cdot dep'[j] + \sum_{i \in L} dis[i] \cdot \sum_{j \in R} dep'[j] + \sum_{i \in L} dep'[i] \cdot \sum_{j \in R} dis[j] \\
&\quad - 2 \cdot \sum_{i \in L, j \in R} (dis[i] + dis[j]) \cdot dep'[lca'(i,j)] + w \cdot \sum_{i \in L, j \in R} dis'(i,j)
\end{aligned}
$$
每个求和式都是可以直接DP的。。。所以用点集 $L \cup R$ 在 $T'$ 上建虚树做01DP即可。
复杂度大概。。 $O(n \log^2 n)$。。那么问题来了。。为什么不用点分呢

#pragma GCC optimize(3)
// Headers for everything this file uses: sort/max/swap, getchar/printf,
// tie, pair/make_pair and vector.
#include <algorithm>
#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
// Modulus applied to every accumulated sum and to the printed answer.
const int mod = 1e9 + 7;
// Capacity of the per-vertex arrays (vertex ids are used as indices).
const int maxn = 1e5 + 100;
// Number of vertices of the current test case, read by work().
int n;
// Adjacency lists of the second tree: E2[u] holds (neighbour, edge length).
vector<pair<int,int> > E2[maxn];
namespace HLD{
    int wson[maxn],sz[maxn],depth[maxn];
    ll dep[maxn];
    int fa[maxn],top[maxn];
    int dfs_clock,l[maxn],r[maxn];
    void dfs1(int u,int Fa){
        l[u] = ++dfs_clock;
        depth[u] = depth[Fa] + 1;
        sz[u] = 1;
        wson[u] = 0;
        fa[u] = Fa;
        for (auto e : E2[u]){
            int v,len;
            tie(v,len) = e;
            if (v == Fa)continue;
            dep[v] = dep[u] + len;
            if (dep[v] >= mod)dep[v] -= mod;
            dfs1(v,u);
            sz[u] += sz[v];
            if (sz[v] > sz[wson[u]])wson[u] = v;
        }
        r[u] = dfs_clock;
    }
    void dfs2(int u,int Fa,int chain){
        top[u] = chain;
        if (wson[u])dfs2(wson[u],u,chain);
        for (auto e : E2[u]){
            int v,len;
            tie(v,len) = e;
            if (v == Fa || v == wson[u])continue;
            dfs2(v,u,v);
        }
    }
    void init(int root){
        dfs_clock = 0;
        dep[root] = 0;
        dfs1(root,0);
        dfs2(root,0,root);
    }
    int lca(int x,int y){
        while (top[x] != top[y]){
            if (depth[top[x]] < depth[top[y]])swap(x,y);
            x = fa[top[x]];
        }
        if (depth[x] < depth[y])swap(x,y);
        return y;
    }
}
// Adjacency lists of the first tree as read from input: (neighbour, edge length).
vector<pair<int,int> > EE1[maxn];
// Linked-list adjacency storage for the tree rebuilt by dfs3d.
int first[maxn*3];   // first[u]: most recently added directed edge leaving u, 0 if none
int des[maxn*6];     // des[t]: vertex that directed edge t points to
int llen[maxn*6];    // llen[t]: weight of directed edge t
int edgeid[maxn*6];  // edgeid[t]: id shared by the two directions of one undirected edge
int nxt[maxn*6];     // nxt[t]: next directed edge in the same list, 0 at the end
int tot;             // number of directed edges stored so far
// Stores the directed edge u -> v (weight w, undirected id `id`) and puts it
// at the front of u's edge list.
inline void add_edge_(int u,int v,int w,int id){
    const int slot = ++tot;
    nxt[slot] = first[u];
    first[u] = slot;
    des[slot] = v;
    llen[slot] = w;
    edgeid[slot] = id;
}
int cnt;                // nodes created so far in the rebuilt first tree
int edge_cnt;           // undirected edges created so far (also the last edge id)
bool banned[maxn * 3];  // banned[e]: undirected edge e was already used as a split edge
int pos[maxn * 3];      // pos[u]: rebuilt-tree node standing for original vertex u
int id[maxn * 3];       // id[x]: original vertex behind node x, 0 for a relay node
ll ans = 0;             // answer accumulated by calc(); reduced and doubled in work()
// Resets the per-test-case state: edge lists of the rebuilt tree, banned
// flags, both input adjacency lists and the running answer.
void clear(){
    for (int node = 1; node <= cnt; ++node) first[node] = 0;
    for (int e = 1; e <= edge_cnt; ++e) banned[e] = false;
    for (int v = 1; v <= n; ++v){
        E2[v].clear();
        EE1[v].clear();
    }
    tot = 0;
    cnt = 0;
    edge_cnt = 0;
    ans = 0;
}
// Adds the undirected edge {u, v} of weight w; both directions share one new id.
inline void add_edge(int u,int v,int w){
    const int shared_id = ++edge_cnt;
    add_edge_(u,v,w,shared_id);
    add_edge_(v,u,w,shared_id);
}
// Rebuilds the first tree below original vertex u (parent `fa`) with relay
// nodes: instead of hanging all children on u, a chain of zero-weight edges
// u - r1 - r2 - ... is created and child k is attached to relay rk with the
// original edge length. Returns the node created for u.
int dfs3d(int u, int fa){
    const int self = ++cnt;
    id[self] = u;
    pos[u] = self;
    int tail = self;                      // current end of the relay chain
    for (const auto &edge : EE1[u]){
        const int child = edge.first;
        if (child == fa) continue;
        const int relay = ++cnt;
        id[relay] = 0;                    // relay nodes stand for no original vertex
        add_edge(tail,relay,0);
        const int child_node = dfs3d(child,u);
        add_edge(relay,child_node,edge.second);
        tail = relay;
    }
    return self;
}
int sz[maxn * 3];
void dfs_sz(int u,int fa){
    sz[u] = 1;
    for (int t = first[u];t;t=nxt[t]){
        int v = des[t],e_id = edgeid[t];
        if (v == fa || banned[e_id])continue;
        dfs_sz(v,u);
        sz[u] += sz[v];
    }
}
void dfs_edge(int u,int fa,int tot_node,
              int &uu,int &vv,int &ww,int &edge_id,int &max_sz){
    for (int t = first[u];t;t=nxt[t]){
        int v = des[t],len = llen[t],e_id = edgeid[t];
        if (v == fa || banned[e_id])continue;
        int max_sz_t = max(sz[v],tot_node - sz[v]);
        if (max_sz_t < max_sz){
            max_sz = max_sz_t;
            uu = u;vv = v;
            ww = len;
            edge_id = e_id;
        }
        dfs_edge(v,u,tot_node,uu,vv,ww,edge_id,max_sz);
    }
}
ll dis[maxn * 3];
void dfs_node(int u,int fa,ll length,vector<int> &nodes){
    if (id[u])nodes.push_back(id[u]);
    dis[u] = length;
    for (int t = first[u];t;t=nxt[t]){
        int v = des[t],len = llen[t],e_id = edgeid[t];
        if (v == fa || banned[e_id])continue;
        int le = length + len;
        if (le >= mod) le -= mod;
        dfs_node(v,u,le, nodes);
    }
}
// Scratch state of calc(), indexed by vertex of the second tree.
int vis[maxn];         // 0: not in the current virtual tree, 1: vertex from L or R, 2: helper vertex
int stk[maxn];         // stack used while linking the virtual tree
int fa[maxn];          // parent in the current virtual tree
ll dp_sum[maxn][2];    // dp_sum[x][c]: sum of dis over side-c vertices merged into x so far
ll dp_cnt[maxn][2];    // dp_cnt[x][c]: number of side-c vertices merged into x so far
ll dp[maxn];           // dp[x]: sum of (dis[i] + dis[j] + w) over opposite-side pairs paired up at x
int color[maxn];       // 1 for the u side, 2 for the v side; only read where vis == 1
// Zeroes the DP cells of vertex x before a virtual-tree pass.
inline void clear(int x){
    dp_sum[x][0] = dp_sum[x][1] = 0;
    dp_cnt[x][0] = dp_cnt[x][1] = 0;
    dp[x] = 0;
}
void calc(int u, int v, int w){
    vector<int> L(0),R(0),nodes(0);
    dfs_node(u,0,0,L);dfs_node(v,0,0,R);
    if (L.size() == 0 || R.size() == 0)return;
    for (int x : L){
        color[x] = 1;
        vis[x] = 1;
        nodes.push_back(x);
    }
    for (int y : R){
        color[y] = 2;
        vis[y] = 1;
        nodes.push_back(y);
    }
    sort(nodes.begin(),nodes.end(),[](int x,int y){
        return HLD::l[x] < HLD::l[y];
    });
    int SZ = nodes.size();
    for (int i=1;i<SZ;i++){
        int temp = HLD::lca(nodes[i-1],nodes[i]);
        if (!vis[temp]){
            nodes.push_back(temp);
            vis[temp] = 2;
        }
    }
    if (!vis[1]){
        nodes.push_back(1);
        vis[1] = 2;
    }
    sort(nodes.begin(),nodes.end(),[](int x,int y){
        return HLD::l[x] < HLD::l[y];
    });
    int top = 1;
    stk[0] = nodes.front();
    for (int i=1;i<nodes.size(); i ++){
        while (HLD::l[nodes[i]] > HLD::r[stk[top-1]]) top--;
        fa[nodes[i]] = stk[top-1];
        stk[top ++] = nodes[i];
    }
    for (int x : nodes)clear(x);
    ll temp_ans = 0;
    ll sum = 0;
    for (int x : L) sum += 1ll * (dis[pos[x]] + w) * HLD::dep[x] % mod;
    temp_ans += sum % mod * R.size();
    sum = 0;
    for (int y : R) sum += 1ll * (dis[pos[y]] + w) * HLD::dep[y] % mod;
    temp_ans += sum % mod * L.size();
    ll sum1 = 0,sum2 = 0;
    for (int x : L)sum1 += dis[pos[x]];
    for (int y : R)sum2 += HLD::dep[y];
    sum1 %= mod;sum2 %= mod;
    temp_ans += sum1 * sum2 % mod;
    sum1 = sum2 = 0;
    for (int x : L)sum1 += HLD::dep[x];
    for (int y : R)sum2 += dis[pos[y]];
    sum1 %= mod;sum2 %= mod;
    temp_ans += sum1 * sum2 % mod;
    for (int i = nodes.size() - 1;i >=0; i--){
        int u = nodes[i], c = vis[u] == 1?color[u] - 1 : -1;
        if (c != -1){
            ll A = dis[pos[u]];
            dp[u] += A * dp_cnt[u][!c] % mod;
            dp[u] += dp_sum[u][!c];
            dp[u] += dp_cnt[u][!c] * w % mod;
            dp[u] %= mod;
            dp_sum[u][c] += A;
            if (dp_sum[u][c] >= mod)dp_sum[u][c] -= mod;
            dp_cnt[u][c] ++;
        }
        temp_ans -= 2ll * dp[u] * HLD::dep[u] % mod;
        dp[fa[u]] += dp_sum[u][0] * dp_cnt[fa[u]][1]% mod + dp_sum[fa[u]][1] * dp_cnt[u][0]% mod;
        dp[fa[u]] += dp_sum[u][1] * dp_cnt[fa[u]][0]% mod + dp_sum[fa[u]][0] * dp_cnt[u][1]% mod;
        dp[fa[u]] += (dp_cnt[fa[u]][1] * dp_cnt[u][0]% mod + dp_cnt[fa[u]][0] * dp_cnt[u][1] % mod) * w % mod;
        dp[fa[u]] %= mod;
        for (int c = 0;c < 2;c ++){
            dp_cnt[fa[u]][c] += dp_cnt[u][c];
            dp_sum[fa[u]][c] += dp_sum[u][c];
            if(dp_cnt[fa[u]][c] >= mod)dp_cnt[fa[u]][c] -= mod;
            if(dp_sum[fa[u]][c] >= mod)dp_sum[fa[u]][c] -= mod;
        }
    }
    ans += (temp_ans % mod) + mod;
    for (int x : nodes){
        vis[x] = 0;
    }
}
// Edge divide-and-conquer on the rebuilt first tree: pick the most balanced
// non-banned edge of root's component, ban it, account for all pairs it
// separates, then recurse into both halves.
void dfs(int root){
    dfs_sz(root,0);
    const int node_cnt = sz[root];
    if (node_cnt != 1){
        // A component with more than one node has a usable edge and every
        // candidate beats the `mod` sentinel, so these are always overwritten.
        int uu = 0, vv = 0, ww = 0, e_id = 0;
        int max_sz = mod;
        dfs_edge(root,0,node_cnt,uu,vv,ww,e_id,max_sz);
        banned[e_id] = true;
        calc(uu, vv, ww);
        dfs(uu);
        dfs(vv);
    }
}
// Reads the next unsigned decimal integer from stdin into x.
// Skips every non-digit character before the number and consumes the one
// character following it. If end of input is reached before any digit, x is 0.
inline void read(int &x){
    x = 0;
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
}
// Solves one test case: reads n and the two trees, runs the edge
// divide-and-conquer and prints the answer modulo `mod`.
void work(){
    read(n);
    // First tree: n - 1 triples (u, v, len), stored in both directions.
    for (int i=1;i<n;i++){
        int u,v,len;
        read(u);read(v);read(len);
        EE1[u].push_back(make_pair(v,len));
        EE1[v].push_back(make_pair(u,len));
    }
    // Second tree, same format.
    for (int i=1;i<n;i++){
        int u,v,len;
        read(u);read(v);read(len);
        E2[u].push_back(make_pair(v,len));
        E2[v].push_back(make_pair(u,len));
    }
    HLD::init(1);
    const int root = dfs3d(1,0);
    dfs(root);
    ans %= mod;
    // calc() handles each unordered pair once; the required sum runs over
    // ordered pairs (i, j), hence the factor 2.
    ans = ans * 2 % mod;
    printf("%lld\n",ans);
}
// Reads the number of test cases, then solves each one and resets the
// global state afterwards.
int main(){
    int T;
    read(T);
    for (; T != 0; --T){
        work();
        clear();
    }
    return 0;
}

你可能感兴趣的:(专题练习)