给出两棵 $n$ 个点的带权树 $T$ 和 $T'$,记 $dis(i,j)$、$dis'(i,j)$ 分别为 $i,j$ 在 $T$、$T'$ 上的距离,求
$$\sum_{i,j \in [1,n]} dis(i,j) \cdot dis'(i,j)$$
对 $T$ 进行边分治。设当前分治的边为 $\langle u,v \rangle$,边权为 $w$,$u$ 一侧的点集为 $L$,$v$ 一侧的点集为 $R$(下文用 $dis[i]$ 表示 $i$ 在 $T$ 上到本侧端点 $u$ 或 $v$ 的距离,$dep'[i]$ 表示 $i$ 在 $T'$ 中的带权深度,$lca'$ 表示 $T'$ 中的最近公共祖先),则路径经过当前边的点对 $(i,j)$ 对答案的贡献是:
$$
\begin{aligned}
ANS &= \sum_{i \in L,\, j \in R} (dis(i,u) + dis(j,v) + w) \cdot dis'(i,j)\\
&= \sum_{i \in L,\, j \in R} (dis[i] + dis[j]) \cdot dis'(i,j) + w \cdot \sum_{i \in L,\, j \in R} dis'(i,j)\\
&= \sum_{i \in L,\, j \in R} (dis[i] + dis[j]) \cdot (dep'[i] + dep'[j] - 2 \cdot dep'[lca'(i,j)]) + w \cdot \sum_{i \in L,\, j \in R} dis'(i,j)\\
&= \sum_{i \in L,\, j \in R} (dis[i] + dis[j]) \cdot (dep'[i] + dep'[j]) - 2 \cdot \sum_{i \in L,\, j \in R} (dis[i] + dis[j]) \cdot dep'[lca'(i,j)] + w \cdot \sum_{i \in L,\, j \in R} dis'(i,j)\\
&= \sum_{i \in L,\, j \in R} (dis[i] \cdot dep'[i] + dis[j] \cdot dep'[j]) + \sum_{i \in L,\, j \in R} (dis[i] \cdot dep'[j] + dis[j] \cdot dep'[i])\\
&\qquad - 2 \cdot \sum_{i \in L,\, j \in R} (dis[i] + dis[j]) \cdot dep'[lca'(i,j)] + w \cdot \sum_{i \in L,\, j \in R} dis'(i,j)\\
&= |R| \cdot \sum_{i \in L} dis[i] \cdot dep'[i] + |L| \cdot \sum_{j \in R} dis[j] \cdot dep'[j] + \sum_{i \in L} dis[i] \cdot \sum_{j \in R} dep'[j] + \sum_{i \in L} dep'[i] \cdot \sum_{j \in R} dis[j]\\
&\qquad - 2 \cdot \sum_{i \in L,\, j \in R} (dis[i] + dis[j]) \cdot dep'[lca'(i,j)] + w \cdot \sum_{i \in L,\, j \in R} dis'(i,j)
\end{aligned}
$$
每个求和式都是可以直接 DP 的……所以用点集 $L \cup R$ 在 $T'$ 上建虚树做 01 DP 即可。
复杂度大概是 $O(n \log^2 n)$……那么问题来了:为什么不用点分治呢?
#pragma GCC optimize(3)
#include <algorithm>
#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;

/// 64-bit signed integer used for every modular accumulator in this file.
using ll = long long;

/// Modulus applied to distances and to the printed answer.
constexpr int mod = 1000000007;
/// Capacity of the per-vertex arrays (1e5 plus 100 entries of slack).
constexpr int maxn = 100000 + 100;

/// Vertex count of the current test case; written by work().
int n;
/// Adjacency lists (neighbour, edge length) of the second tree, 1-indexed.
vector<pair<int, int>> E2[maxn];
namespace HLD{
int wson[maxn],sz[maxn],depth[maxn];
ll dep[maxn];
int fa[maxn],top[maxn];
int dfs_clock,l[maxn],r[maxn];
void dfs1(int u,int Fa){
l[u] = ++dfs_clock;
depth[u] = depth[Fa] + 1;
sz[u] = 1;
wson[u] = 0;
fa[u] = Fa;
for (auto e : E2[u]){
int v,len;
tie(v,len) = e;
if (v == Fa)continue;
dep[v] = dep[u] + len;
if (dep[v] >= mod)dep[v] -= mod;
dfs1(v,u);
sz[u] += sz[v];
if (sz[v] > sz[wson[u]])wson[u] = v;
}
r[u] = dfs_clock;
}
void dfs2(int u,int Fa,int chain){
top[u] = chain;
if (wson[u])dfs2(wson[u],u,chain);
for (auto e : E2[u]){
int v,len;
tie(v,len) = e;
if (v == Fa || v == wson[u])continue;
dfs2(v,u,v);
}
}
void init(int root){
dfs_clock = 0;
dep[root] = 0;
dfs1(root,0);
dfs2(root,0,root);
}
int lca(int x,int y){
while (top[x] != top[y]){
if (depth[top[x]] < depth[top[y]])swap(x,y);
x = fa[top[x]];
}
if (depth[x] < depth[y])swap(x,y);
return y;
}
}
/// Adjacency lists (neighbour, edge length) of the first tree as read from input.
vector<pair<int,int> > EE1[maxn];

// Linked-list ("forward star") storage of the rebuilt first tree.
int first[maxn*3];   // first[u]: most recently added arc leaving u, 0 if none
int des[maxn*6];     // des[t]: head (destination) of arc t
int llen[maxn*6];    // llen[t]: length of arc t
int edgeid[maxn*6];  // edgeid[t]: id shared by the two arcs of one undirected edge
int nxt[maxn*6];     // nxt[t]: next arc leaving the same node, 0 at the end
int tot;             // number of arcs stored so far (arc ids start at 1)

/// Appends the directed arc u -> v to u's list.
/// @param w   arc length
/// @param id  id of the undirected edge this arc belongs to
inline void add_edge_(int u,int v,int w,int id){
    const int slot = ++tot;
    nxt[slot] = first[u];
    first[u] = slot;
    des[slot] = v;
    llen[slot] = w;
    edgeid[slot] = id;
}
// cnt: number of nodes in the rebuilt tree; edge_cnt: number of undirected edges in it.
int cnt, edge_cnt;
// banned[e]: undirected edge e has already been removed by the divide and conquer.
bool banned[maxn * 3];
// pos[u]: rebuilt-tree node that stands for original vertex u.
int pos[maxn * 3];
// id[x]: original vertex represented by rebuilt-tree node x, 0 for auxiliary nodes.
int id[maxn * 3];
// Running answer of the current test case.
ll ans = 0;

/// Resets the per-test-case state so the next call to work() starts clean.
/// Uses the current values of cnt, n and edge_cnt to bound what is wiped.
void clear(){
    for (int node = cnt; node >= 1; --node) first[node] = 0;
    for (int e = edge_cnt; e >= 1; --e) banned[e] = false;
    for (int v = 1; v <= n; ++v){
        E2[v].clear();
        EE1[v].clear();
    }
    tot = 0;
    cnt = 0;
    edge_cnt = 0;
    ans = 0;
}
/// Adds the undirected edge {u, v} of length w to the rebuilt tree.
/// Both directed arcs receive the same freshly allocated edge id.
inline void add_edge(int u,int v,int w){
    const int eid = ++edge_cnt;
    add_edge_(u, v, w, eid);
    add_edge_(v, u, w, eid);
}
/// Rebuilds the subtree of original vertex u (first tree) into the forward-star
/// graph. For every child a fresh auxiliary node is appended to a chain hanging
/// off u through zero-length edges, and the child's subtree is attached to that
/// auxiliary node with the original edge length. Distances between original
/// vertices are therefore unchanged.
/// @param u   original vertex being expanded
/// @param fa  parent of u in the first tree (0 for the root)
/// @return    rebuilt-tree node id assigned to u
int dfs3d(int u, int fa){
    const int self = ++cnt;
    pos[u] = self;
    id[self] = u;
    int tail = self;                       // current end of u's auxiliary chain
    for (const auto &edge : EE1[u]){
        const int child = edge.first;
        if (child == fa) continue;
        const int weight = edge.second;
        const int aux = ++cnt;
        id[aux] = 0;                       // auxiliary nodes stand for no original vertex
        add_edge(tail, aux, 0);
        add_edge(aux, dfs3d(child, u), weight);
        tail = aux;
    }
    return self;
}
int sz[maxn * 3];
void dfs_sz(int u,int fa){
sz[u] = 1;
for (int t = first[u];t;t=nxt[t]){
int v = des[t],e_id = edgeid[t];
if (v == fa || banned[e_id])continue;
dfs_sz(v,u);
sz[u] += sz[v];
}
}
void dfs_edge(int u,int fa,int tot_node,
int &uu,int &vv,int &ww,int &edge_id,int &max_sz){
for (int t = first[u];t;t=nxt[t]){
int v = des[t],len = llen[t],e_id = edgeid[t];
if (v == fa || banned[e_id])continue;
int max_sz_t = max(sz[v],tot_node - sz[v]);
if (max_sz_t < max_sz){
max_sz = max_sz_t;
uu = u;vv = v;
ww = len;
edge_id = e_id;
}
dfs_edge(v,u,tot_node,uu,vv,ww,edge_id,max_sz);
}
}
ll dis[maxn * 3];
void dfs_node(int u,int fa,ll length,vector<int> &nodes){
if (id[u])nodes.push_back(id[u]);
dis[u] = length;
for (int t = first[u];t;t=nxt[t]){
int v = des[t],len = llen[t],e_id = edgeid[t];
if (v == fa || banned[e_id])continue;
int le = length + len;
if (le >= mod) le -= mod;
dfs_node(v,u,le, nodes);
}
}
// State of the virtual tree built on the second tree inside calc().
// vis[x]: 0 = not part of the current virtual tree, 1 = key vertex (in L or R),
//         2 = helper vertex added as an LCA or as the root.
int vis[maxn];
// Stack holding the current root-to-vertex chain while the virtual tree is built.
int stk[maxn];
// fa[x]: parent of x in the virtual tree.
int fa[maxn];
// For side c (0 = L, 1 = R) inside x's virtual subtree:
//   dp_sum[x][c] = sum of dis over key vertices of side c,
//   dp_cnt[x][c] = number of key vertices of side c.
// dp[x] accumulates (dis[i] + dis[j] + w) over opposite-side pairs meeting at x.
ll dp_sum[maxn][2], dp_cnt[maxn][2],dp[maxn];
// color[x]: 1 when x lies on the L side of the removed edge, 2 on the R side.
int color[maxn];

/// Zeroes the DP cells of vertex x.
inline void clear(int x){
    dp[x] = 0;
    dp_sum[x][0] = 0;
    dp_sum[x][1] = 0;
    dp_cnt[x][0] = 0;
    dp_cnt[x][1] = 0;
}
void calc(int u, int v, int w){
vector<int> L(0),R(0),nodes(0);
dfs_node(u,0,0,L);dfs_node(v,0,0,R);
if (L.size() == 0 || R.size() == 0)return;
for (int x : L){
color[x] = 1;
vis[x] = 1;
nodes.push_back(x);
}
for (int y : R){
color[y] = 2;
vis[y] = 1;
nodes.push_back(y);
}
sort(nodes.begin(),nodes.end(),[](int x,int y){
return HLD::l[x] < HLD::l[y];
});
int SZ = nodes.size();
for (int i=1;i<SZ;i++){
int temp = HLD::lca(nodes[i-1],nodes[i]);
if (!vis[temp]){
nodes.push_back(temp);
vis[temp] = 2;
}
}
if (!vis[1]){
nodes.push_back(1);
vis[1] = 2;
}
sort(nodes.begin(),nodes.end(),[](int x,int y){
return HLD::l[x] < HLD::l[y];
});
int top = 1;
stk[0] = nodes.front();
for (int i=1;i<nodes.size(); i ++){
while (HLD::l[nodes[i]] > HLD::r[stk[top-1]]) top--;
fa[nodes[i]] = stk[top-1];
stk[top ++] = nodes[i];
}
for (int x : nodes)clear(x);
ll temp_ans = 0;
ll sum = 0;
for (int x : L) sum += 1ll * (dis[pos[x]] + w) * HLD::dep[x] % mod;
temp_ans += sum % mod * R.size();
sum = 0;
for (int y : R) sum += 1ll * (dis[pos[y]] + w) * HLD::dep[y] % mod;
temp_ans += sum % mod * L.size();
ll sum1 = 0,sum2 = 0;
for (int x : L)sum1 += dis[pos[x]];
for (int y : R)sum2 += HLD::dep[y];
sum1 %= mod;sum2 %= mod;
temp_ans += sum1 * sum2 % mod;
sum1 = sum2 = 0;
for (int x : L)sum1 += HLD::dep[x];
for (int y : R)sum2 += dis[pos[y]];
sum1 %= mod;sum2 %= mod;
temp_ans += sum1 * sum2 % mod;
for (int i = nodes.size() - 1;i >=0; i--){
int u = nodes[i], c = vis[u] == 1?color[u] - 1 : -1;
if (c != -1){
ll A = dis[pos[u]];
dp[u] += A * dp_cnt[u][!c] % mod;
dp[u] += dp_sum[u][!c];
dp[u] += dp_cnt[u][!c] * w % mod;
dp[u] %= mod;
dp_sum[u][c] += A;
if (dp_sum[u][c] >= mod)dp_sum[u][c] -= mod;
dp_cnt[u][c] ++;
}
temp_ans -= 2ll * dp[u] * HLD::dep[u] % mod;
dp[fa[u]] += dp_sum[u][0] * dp_cnt[fa[u]][1]% mod + dp_sum[fa[u]][1] * dp_cnt[u][0]% mod;
dp[fa[u]] += dp_sum[u][1] * dp_cnt[fa[u]][0]% mod + dp_sum[fa[u]][0] * dp_cnt[u][1]% mod;
dp[fa[u]] += (dp_cnt[fa[u]][1] * dp_cnt[u][0]% mod + dp_cnt[fa[u]][0] * dp_cnt[u][1] % mod) * w % mod;
dp[fa[u]] %= mod;
for (int c = 0;c < 2;c ++){
dp_cnt[fa[u]][c] += dp_cnt[u][c];
dp_sum[fa[u]][c] += dp_sum[u][c];
if(dp_cnt[fa[u]][c] >= mod)dp_cnt[fa[u]][c] -= mod;
if(dp_sum[fa[u]][c] >= mod)dp_sum[fa[u]][c] -= mod;
}
}
ans += (temp_ans % mod) + mod;
for (int x : nodes){
vis[x] = 0;
}
}
/// Edge divide and conquer on the rebuilt first tree: removes the most balanced
/// edge of root's component, accounts for the pairs it separates, then recurses
/// into both halves. A component consisting of a single node ends the recursion.
void dfs(int root){
    dfs_sz(root, 0);
    const int node_cnt = sz[root];
    if (node_cnt == 1) return;

    // Filled in by dfs_edge; a component with two or more nodes has an edge.
    int uu, vv, ww, e_id;
    int max_sz = mod;               // larger than any possible component size
    dfs_edge(root, 0, node_cnt, uu, vv, ww, e_id, max_sz);

    banned[e_id] = true;
    calc(uu, vv, ww);
    dfs(uu);
    dfs(vv);
}
/// Reads the next non-negative decimal integer from stdin into x.
/// Any non-digit characters before the number are skipped. When the input is
/// exhausted before a digit is found, x is left at 0 and the function returns.
/// @param x  out: parsed value (0 when no digits were available)
inline void read(int &x){
    x = 0;
    // getchar() returns int so that EOF can be told apart from real characters;
    // storing it in a char made the skip loop spin forever at end of input.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
}
/// Solves one test case: reads n and the n-1 weighted edges of each tree from
/// stdin, runs the edge divide and conquer and prints the answer modulo `mod`.
/// Input is expected as "u v len" triples, first tree first.
void work(){
    read(n);
    // First tree: the one that gets rebuilt and divided.
    for (int i = 1; i < n; i++){
        int u, v, len;
        read(u); read(v); read(len);
        EE1[u].push_back(make_pair(v, len));
        EE1[v].push_back(make_pair(u, len));
    }
    // Second tree: the one the virtual trees are built on.
    for (int i = 1; i < n; i++){
        int u, v, len;
        read(u); read(v); read(len);
        E2[u].push_back(make_pair(v, len));
        E2[v].push_back(make_pair(u, len));
    }
    HLD::init(1);
    const int root = dfs3d(1, 0);
    // The former maximum-degree scan over the rebuilt tree was dead code
    // (its result was never read) and has been removed.
    dfs(root);
    // calc() counts each unordered pair once; the required sum runs over
    // ordered pairs (i, j), hence the factor 2.
    ans %= mod;
    ans = ans * 2 % mod;
    printf("%lld\n", ans);
}
/// Entry point: reads the number of test cases, then solves each one and
/// resets the global state afterwards.
int main(){
    int T;
    read(T);
    for (int done = 0; done < T; ++done){
        work();
        clear();
    }
    return 0;
}