有一棵有 n n n个节点的树,每条边有一个边权 w w w。有 m m m个特殊点,将这些点记为集合 A A A。
将 A A A中的元素随机打乱得到序列 a a a,求 ∑ i = 2 m d ( a i − 1 , a i ) \sum\limits_{i=2}^md(a_{i-1},a_i) i=2∑md(ai−1,ai)的期望值模 998244353 998244353 998244353后的值,其中 d ( x , y ) d(x,y) d(x,y)表示 x x x到 y y y的边权和。
有 q q q次修改,每次修改会将与 x x x相连的边的权值增加 k k k。求每次修改后上述式子的期望值。
1 ≤ n ≤ 5 × 1 0 5 , m ≤ n , 1 ≤ q ≤ 5 × 1 0 5 1\leq n\leq 5\times 10^5,m\leq n,1\leq q\leq 5\times 10^5 1≤n≤5×105,m≤n,1≤q≤5×105
1 ≤ w , k ≤ 1 0 9 1\leq w,k\leq 10^9 1≤w,k≤109
对于每组特殊点 x , y x,y x,y,我们考虑有多少种方案会计算到 d ( x , y ) d(x,y) d(x,y)的贡献。在确定 x , y x,y x,y在 a a a中相邻之后,其他 m − 2 m-2 m−2个数有 ( m − 2 ) ! (m-2)! (m−2)!种放法, x , y x,y x,y中较前的数可以放在第一个到第 m − 1 m-1 m−1个位置上,确定了前一个数,则后一个数也确定了,而这两个数的顺序可以为 x , y x,y x,y或者 y , x y,x y,x,所以还要乘 2 2 2,也就是说有 2 ( m − 2 ) ! × ( m − 1 ) = 2 ( m − 1 ) ! 2(m-2)!\times (m-1)=2(m-1)! 2(m−2)!×(m−1)=2(m−1)!种方案会计算到 d ( x , y ) d(x,y) d(x,y)的贡献。而题目要求的是期望值,总共有 m ! m! m!种方案,那么 d ( x , y ) d(x,y) d(x,y)对答案的贡献为 2 ( m − 1 ) ! m ! × d ( x , y ) = 2 m × d ( x , y ) \dfrac{2(m-1)!}{m!}\times d(x,y)=\dfrac 2m\times d(x,y) m!2(m−1)!×d(x,y)=m2×d(x,y)。
下面,我们要求每条边被多少 d ( x , y ) d(x,y) d(x,y)计算过,这用一个 d f s dfs dfs即可算出,记这个值为 t d i td_i tdi。然后,求出所有边 i i i的 w i w_i wi与 t d i td_i tdi之积的和,也就是 ∑ i w i × t d i \sum\limits_iw_i\times td_i i∑wi×tdi, m 2 × ∑ i w i × t d i \dfrac m2\times \sum\limits_iw_i\times td_i 2m×i∑wi×tdi即为答案。
我们考虑每次修改对答案的贡献。设与 i i i相连的边的 t d td td值之和为 t w i tw_i twi,则每次修改会让 ∑ i w i × t d i \sum\limits_iw_i\times td_i i∑wi×tdi增加 k × t w i k\times tw_i k×twi。那么,我们可以 O ( 1 ) O(1) O(1)修改。因为题目只需要求答案,所以我们不需要真的去修改 w i w_i wi。
时间复杂度为 O ( n + q ) O(n+q) O(n+q)。
#include
using namespace std;
const int N=500000;
const long long mod=998244353;
int n,m,q,z[N+5],siz[N+5];
long long ans,pt,w[N+5],td[N+5],tw[N+5];
vector<pair<int,int>>g[N+5];
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
void dfs(int u,int fa){
siz[u]=z[u];
for(auto p:g[u]){
int v=p.first,id=p.second;
if(v==fa) continue;
dfs(v,u);
siz[u]+=siz[v];
td[id]=1ll*(m-siz[v])*siz[v]%mod;
}
}
int main()
{
// freopen("sakuya.in","r",stdin);
// freopen("sakuya.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1,x,y;i<n;i++){
scanf("%d%d%lld",&x,&y,&w[i]);
g[x].push_back({y,i});
g[y].push_back({x,i});
}
for(int i=1,x;i<=m;i++){
scanf("%d",&x);
z[x]=1;
}
dfs(1,0);
for(int i=1;i<n;i++){
ans=(ans+td[i]*w[i])%mod;
}
for(int i=1;i<=n;i++){
for(auto p:g[i]){
tw[i]=(tw[i]+td[p.second])%mod;
}
}
scanf("%d",&q);
long long tq=mi(m,mod-2)*2%mod;
for(int o=1,x,k;o<=q;o++){
scanf("%d%d",&x,&k);
ans=(ans+tw[x]*k)%mod;
pt=ans*tq%mod;
printf("%lld\n",pt);
}
return 0;
}