题目大意:
给定一张无向图,图中点分为黑点和白点。对于第i条边,其边长为2i,问所有黑白点之间最短路之和。
这个题仔细思考一下,不难发现对于第i条边,有
21 + 22 + … + 2i-1 < 2i
即前面所有的边权和加起来也没有这一条边这么多,因此当第i条边加入时,可以连通的黑点和白点间产生的路径就是最短路径。如果这条边连接的两个顶点已经连通,则不用加入改点。
分析到这里,前面的部分就可以使用并查集直接维护一颗生成树了。接着这个问题就变成了树上统计问题。
生成树过后,问题就转变为在一棵无根树上,所有的黑白点对之间的距离和是多少。
为了解决这个问题,我们可以在每一个节点统计几个值:
不难得到以下转移
对于节点k,其父节点为f:
子树节点颜色数量统计有:
n o d e [ k ] . d p [ 0 ] [ 1 ] = ∑ e ( j , k ) , j ≠ f n o d e [ j ] . d p [ 0 ] [ 1 ] node[k].dp[0][1] = \sum_{e(j,k),j≠f} node[j].dp[0][1] node[k].dp[0][1]=e(j,k),j=f∑node[j].dp[0][1]
n o d e [ k ] . d p [ 1 ] [ 1 ] = ∑ e ( j , k ) , j ≠ f n o d e [ j ] . d p [ 1 ] [ 1 ] node[k].dp[1][1] = \sum_{e(j,k),j≠f} node[j].dp[1][1] node[k].dp[1][1]=e(j,k),j=f∑node[j].dp[1][1]
子树距离和统计有:
n o d e [ k ] . d p [ 0 ] [ 0 ] = ∑ e ( j , k ) , j ≠ f ( n o d e [ j ] . d p [ 0 ] [ 0 ] + n o d e [ j ] . d p [ 0 ] [ 1 ] ∗ w ( e ) ) node[k].dp[0][0] = \sum_{e(j,k),j≠f} (node[j].dp[0][0] +node[j].dp[0][1]*w(e)) node[k].dp[0][0]=e(j,k),j=f∑(node[j].dp[0][0]+node[j].dp[0][1]∗w(e))
n o d e [ k ] . d p [ 1 ] [ 0 ] = ∑ e ( j , k ) , j ≠ f ( n o d e [ j ] . d p [ 1 ] [ 0 ] + n o d e [ j ] . d p [ 1 ] [ 1 ] ∗ w ( e ) ) node[k].dp[1][0] = \sum_{e(j,k),j≠f} (node[j].dp[1][0] +node[j].dp[1][1]*w(e)) node[k].dp[1][0]=e(j,k),j=f∑(node[j].dp[1][0]+node[j].dp[1][1]∗w(e))
子树中答案统计有:
n o d e [ k ] . s u m = ∑ e ( j , k ) , j ≠ f ( n o d e [ j ] . s u m + ( n o d e [ k ] . d p [ 0 ] [ 1 ] − n o d e [ j ] . d p [ 0 ] [ 1 ] ) ∗ ( n o d e [ j ] . d p [ 1 ] [ 0 ] + w ( e ) ∗ n o d e [ j ] . d p [ 1 ] [ 1 ] ) + ( n o d e [ k ] . d p [ 1 ] [ 1 ] − n o d e [ j ] . d p [ 1 ] [ 1 ] ) ∗ ( n o d e [ j ] . d p [ 0 ] [ 0 ] + w ∗ n o d e [ j ] . d p [ 0 ] [ 1 ] ) ) ) node[k].sum = \sum_{e(j,k),j≠f} (node[j].sum + (node[k].dp[0][1] - node[j].dp[0][1]) * (node[j].dp[1][0] + w(e) * node[j].dp[1][1]) + (node[k].dp[1][1]- node[j].dp[1][1]) * (node[j].dp[0][0] + w * node[j].dp[0][1]))) node[k].sum=e(j,k),j=f∑(node[j].sum+(node[k].dp[0][1]−node[j].dp[0][1])∗(node[j].dp[1][0]+w(e)∗node[j].dp[1][1])+(node[k].dp[1][1]−node[j].dp[1][1])∗(node[j].dp[0][0]+w∗node[j].dp[0][1])))
记得给节点dp数据根据改点颜色初始化!!!
不要被这些可怕的公式给吓倒了。尤其是统计子树中答案的公式,其实并不复杂。
统计答案只需要对于该节点下每一个子节点子树,都统计上整棵子树中除去该子节点子树下所有黑点到该子节点树下所有白点的距离以及反过来所有白点到黑点的距离即可,这可能听起来有些绕,不过这确实值得细品,这里单用语言描述可能会越来越混乱,建议读者手玩一个例子来感受一下。
实在不行,就对着代码品也行QAQ:
#include
#include
using namespace std;
const int N = 1e6 + 50;
const long long mo = 1e9 + 7;
struct Edge{
int point;
int next;
long long w;
}nxt[N];
struct Node{
int code;
long long dp[2][2];
long long sum;
}node[N];
int fa[N];
int head[N];
int T,n,m,tot;
int find(int k){
if(k == fa[k])
return k;
fa[k]= find(fa[k]);
return fa[k];
}
void getMin(int &x,int &y){
if(x > y){
x ^= y;
y ^= x;
x ^= y;
}
}
long long calc(int p){
long long ans = 1;
long long pow = 2;
while(p){
if(p & 1) ans = (ans * pow) % mo;
p >>= 1;
pow = (pow * pow) % mo;
}
return ans;
}
void link(int x,int y,long long w){
nxt[++tot] = {y,head[x],w};
head[x] = tot;
}
void dfs(int k,int f){
node[k].dp[node[k].code][1] = 1;
for(int i = head[k],j;i;i = nxt[i].next){
j = nxt[i].point;
if(j == f) continue;
dfs(j,k);
node[k].dp[0][0] = (node[k].dp[0][0] + node[j].dp[0][0] + (node[j].dp[0][1] * nxt[i].w) % mo) % mo;
node[k].dp[1][0] = (node[k].dp[1][0] + node[j].dp[1][0] + (node[j].dp[1][1] * nxt[i].w) % mo) % mo;
node[k].dp[0][1] += node[j].dp[0][1];
node[k].dp[1][1] += node[j].dp[1][1];
node[k].sum = (node[k].sum + node[j].sum) % mo;
}
long long sum0 = node[k].dp[0][0];
long long cnt0 = node[k].dp[0][1];
long long sum1 = node[k].dp[1][0];
long long cnt1 = node[k].dp[1][1];
long long w;
for(int i = head[k],j;i;i = nxt[i].next){
j = nxt[i].point;
if(j == f) continue;
w = nxt[i].w;
node[k].sum = (node[k].sum + ((cnt0 - node[j].dp[0][1]) * (node[j].dp[1][0] + w * node[j].dp[1][1])) % mo) % mo;
node[k].sum = (node[k].sum + ((cnt1 - node[j].dp[1][1]) * (node[j].dp[0][0] + w * node[j].dp[0][1])) % mo) % mo;
}
}
int main(){
for(cin >> T;T;T--){
tot = 1;
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;i++){
node[i] = {0,0,0,0,0,0};
head[i] = 0;
scanf("%d",&node[i].code);
fa[i] = i;
}
for(int i = 1,x,y;i <= m;i++){
scanf("%d%d",&x,&y);
getMin(x,y);
if(find(x) == find(y))
continue;
link(x,y,calc(i));
link(y,x,calc(i));
fa[fa[y]] = fa[x];
}
dfs(1,0);
cout << node[1].sum << endl;
}
}