题目:给出一棵树,边权在 [ 0 , 1 ] [0,1] [0,1]间随机,求直径长度的期望值。
1.枚举直径中点在那条边上。
直径中点落在点上的几率十分小,可以看做0.
所以我们枚举每一条边求直径中点在这条边上时的(概率*此时直径长度的期望)。
2.求出答案我们需要维护每个点到子树内的最长链的长度分布情况,
这个是个连续的量,我们需要用一个函数来描述他:
f u ( x ) = P r [ d ( u ) ≤ x ] f_u(x) = Pr[d(u)\leq x] fu(x)=Pr[d(u)≤x],
比如,对于1->2->3这样一个简单的两个点的树。
f 2 ( x ) = x , f 3 ( x ) = 1 f_2(x) = x , f_3(x) = 1 f2(x)=x,f3(x)=1
f 1 ( x ) f_1(x) f1(x)分成了两段,
在 x ∈ [ 0 , 1 ] x\in[0,1] x∈[0,1], f 1 ( x ) = x 2 2 f_1(x) = \frac {x^2}2 f1(x)=2x2
在 x ∈ [ 1 , 2 ] x\in[1,2] x∈[1,2], f 1 ( x ) = 2 x − x 2 2 − 1 f_1(x) = 2x-\frac {x^2}2-1 f1(x)=2x−2x2−1
所以简单的说我们的树形 D P DP DP中要维护的答案是一个个分段函数。
要如何转移呢?
需要两个操作,把儿子的函数加上一条边后取最长链即取 max \max max和加上一条边。
取 max \max max很简单, P r [ m a x ( a , b ) ≤ x ] = P r [ a ≤ x ] P r [ b ≤ x ] Pr[max(a,b)\leq x] = Pr[a\leq x] Pr[b\leq x] Pr[max(a,b)≤x]=Pr[a≤x]Pr[b≤x]
就是两个函数乘起来。(注意所有的函数都是在整数处分段,这为我们写代码提供了便利)。
加上一条边很好想出来: f ( x ) = ∫ x − 1 x g ( y ) d y f(x) = \int_{x-1}^{x}g(y)\rm{d} y f(x)=∫x−1xg(y)dy
表示前面的和 ≤ y \leq y ≤y,加入一个长为 x − y x-y x−y的边。
注意这虽然是一个定积分,但是其上下界与 x x x相关,所以其实是关于 x x x的一个函数,而通过定积分与 y y y相关的部分就不在了,所以,实际上的操作是把 g g g的每一段都不定积分后得到 h ( x ) h(x) h(x),对每一段的 f ( x ) = h ( x ) − h ( x − 1 ) f(x) = h(x) - h(x-1) f(x)=h(x)−h(x−1),注意这里 h ( x ) h(x) h(x)与 h ( x − 1 ) h(x-1) h(x−1)是一个分段函数中不同的两段。
然后我们还要加入新的一段,因为定义域被拓展了 1 1 1。但是对于最后的一段此时只存在 h ( x − 1 ) h(x-1) h(x−1),而 h ( x ) h(x) h(x)是无定义的,发现最后一段一定是 f ( p ) = p − h ( p − 1 ) + 常 数 f(p) = p-h(p-1)+常数 f(p)=p−h(p−1)+常数,稍微思考一下即可得出常数的计算方法。
有了这两个操作我们就可以通过树形 D P DP DP求出对于一条边 ( u , v ) (u,v) (u,v), u u u往这条边的反方向的概率分布函数, v v v同理。
3.如何通过 f u , f v f_u,f_v fu,fv求出中点在这条边上的概率*长度期望?
首先需要注意, u u u或 v v v没有儿子的情况是需要特殊考虑的,因为我们默认分段函数的每一段都是一个段而不是一个点,所以没有儿子定义域就只有一个点那么就很烦。
考虑如何计算答案,如果u,v其中之一是叶子,不妨认为是v ,那么问题变成了: l ≥ d ( u ) l\geq d(u) l≥d(u) ,求 E ( l + d ( u ) ) E(l+d(u)) E(l+d(u))。由期望的线性性,可以对两部分分别算贡献,可以得到:
E ( l ) = ∫ 0 1 l f u ( l ) d l E(l) = \int_0^1 lf_u(l){\rm d}l E(l)=∫01lfu(l)dl
表示枚举长度 l l l,另一边 d ( u ) ≤ l d(u)\leq l d(u)≤l的概率是 f u ( l ) f_u(l) fu(l)
E ( d ( u ) ) = ∫ 0 1 ( ∫ 0 l x f u ′ ( x ) d x ) d l E(d(u)) = \int_0^1(\int_0^l xf'_u(x)dx){\rm d}l E(d(u))=∫01(∫0lxfu′(x)dx)dl
这里用到了求导,这样可以微分求出长度 = x =x =x的概率,然后再 × x \times x ×x,再外层枚举 l l l。
这些操作虽然听起来很奇怪但是他们都是可以写出来的。(还很短。)
一般情况也这么推导即可。
A C C o d e \mathrm{AC \ Code} AC Code
#include
#define maxn 105
#define mod 998244353
#define rep(i,j,k) for(int i=(j),LIM=(k);i<=LIM;i++)
#define per(i,j,k) for(int i=(j),LIM=(k);i>=LIM;i--)
#define Ct const
#define pb push_back
using namespace std;
typedef vector<int> Poly;
typedef vector<Poly> Func;
int add(int b,int c){ return (1ll*b+c)%mod;/*int t=b+c;return t>=mod?t-mod:t;*/ }
int sub(int b,int c){ return (1ll*b-c)%mod;/*int t=b-c;return t<0?t+mod:t;*/ }
void inc(int &a,int c){ a = (1ll*a+c)%mod; /*((a+=c)>=mod)&&(a-=mod);*/ }
int n,inv[maxn]={1,1},C[maxn][maxn];
int info[maxn],Prev[maxn<<1],to[maxn<<1],xu[maxn],xv[maxn],cnt_e=1;
void Node(int u,int v){ Prev[++cnt_e]=info[u],info[u]=cnt_e,to[cnt_e]=v; }
Poly X;
Func F[maxn*3];
Poly operator +(Ct Poly &A,Ct Poly &B){
Poly r(max(A.size() , B.size()));
rep(i,0,r.size()-1) r[i] = add(i<A.size()?A[i]:0,i<B.size()?B[i]:0);
return r;
}
Poly operator -(Ct Poly &A,Ct Poly &B){
Poly r(max(A.size() , B.size()));
rep(i,0,r.size()-1) r[i] = sub(i<A.size()?A[i]:0,i<B.size()?B[i]:0);
return r;
}
Poly operator *(Ct Poly &A,Ct Poly &B){
Poly r(A.size()+B.size()-1);
rep(i,0,A.size()-1) rep(j,0,B.size()-1) inc(r[i+j],1ll*A[i]*B[j]%mod);
return r;
}
Poly INT(Ct Poly &A){
Poly r(A.size()+1);
rep(i,1,A.size()) r[i]=1ll*A[i-1]*inv[i]%mod;
return r;
}
Poly Shift(Ct Poly &A,int x){
if(A.empty()) return Poly();
Poly B(A.size()),r(A.size());
rep(i,B[0]=1,B.size()-1) B[i]=1ll*B[i-1]*x%mod;
rep(i,0,B.size()-1) rep(j,0,i)
inc(r[i-j],1ll*A[i]*B[j]%mod*C[i][j]%mod);
return r;
}
Poly DER(Ct Poly &A){
Poly r(A.size()-1);
rep(i,1,A.size()-1) r[i-1]=1ll*A[i]*i%mod;
return r;
}
void Print(Ct Poly &A){
if(A.empty()) puts("zero");
rep(i,0,A.size()-1) printf("%d%c",A[i]," \n"[i==A.size()-1]);
}
void Print(Ct Func &A){
puts("");
rep(i,0,A.size()-1) Print(A[i]);
puts("");
}
int Eval(Ct Poly &A,int x){
int pw = 1, r = 0;
rep(i,0,A.size()-1) inc(r , 1ll * pw * A[i] % mod) , pw = 1ll * x * pw % mod;
return r;
}
Func operator -(Ct Func &A,Ct Func &B){
Func r(max(A.size() , B.size()));
rep(i,0,r.size()-1) r[i] = (i<A.size()?A[i]:Poly())-(i<B.size()?B[i]:Poly());
return r;
}
Func operator +(Ct Func &A,Ct Func &B){
Func r(max(A.size() , B.size()));
rep(i,0,r.size()-1) r[i] = (i<A.size()?A[i]:Poly())+(i<B.size()?B[i]:Poly());
return r;
}
Func operator *(Ct Func &A,Ct Func &B){
Func r(max(A.size(),B.size()));
rep(i,0,r.size()-1) r[i] = i>=A.size()?B[i]:i>=B.size()?A[i]:A[i]*B[i];
return r;
}
Func operator *(Ct Func &A,Ct Poly &B){
Func r(A.size());
rep(i,0,r.size()-1) r[i] = A[i] * B;
return r;
}
Func INT(Ct Func &A){
Func r(A.size());
rep(i,0,A.size()-1){
r[i] = INT(A[i]);
inc(r[i][0] , sub(i?Eval(r[i-1],i):0,Eval(r[i],i)));
}
return r;
}
Func Shift(Ct Func &A,int x){
Func r;
for(int i=0;i+x<(int)A.size();i++)
r.push_back(i+x>=0?Shift(A[i+x],x):Poly());
return r;
}
Func DER(Ct Func &A){
Func r(A.size());
rep(i,0,A.size()-1) r[i] = DER(A[i]);
return r;
}
int Eval(Ct Func &A,int x){
if(A.empty()) return 1;
return Eval(A[max(0,min((int)A.size()-1,x))],x);
}
Func dfs(int u,int fr){
Func &pu = F[fr];
if(pu.empty())
for(int i=info[u],v;i;i=Prev[i]) if(i^1^fr){
Func pv = dfs(v = to[i] , i);
int L = pv.size();
Func a = INT(pv) , b = Shift(a,-1) , c(pv.size() + 1);
if(pv.empty()) c = Func(1,X);
else{
rep(j,0,L-1) c[j] = a[j] - b[j];
c[L].pb(add(mod-L,Eval(a,L))),c[L].pb(1);
c[L] = c[L] - b[L];
}
pu = pu * c;
}
return pu;
}
int calc(Func A){
int r = 0;
Func a = INT(A*X);
inc(r,Eval(a,1));
a = INT(INT(DER(A)*X));
inc(r,Eval(a,1));
return r;
}
int calc(Func A,Func B){
int r = 0 , L = max(A.size() , B.size())+1 , lena = A.size() , lenb = B.size();
for(;A.size() < L;) A.pb(Poly(1,1));
for(;B.size() < L;) B.pb(Poly(1,1));
Func a;
a = INT(((Shift(B,1) - B) * Poly(2,1)- Shift(INT(DER(B) * X),1) + INT(DER(B) * X)) * DER(A) * X);
inc(r,Eval(a,lena));
a = INT((INT(DER(A)*Poly(2,1))-Shift(INT(DER(A)*Poly(2,1)),-1)-(A-Shift(A,-1))*X) * X * DER(B));
inc(r,Eval(a,lenb));
a = INT((Shift(INT(B*X),1)-INT(B*X)+INT(B)*X-Shift(INT(B),1)*X-B*Poly(1,inv[2]))*DER(A));
inc(r,Eval(a,lena));
return r;
}
int solve(int id){
Func a = dfs(xu[id],id<<1|1) , b = dfs(xv[id],id<<1);
if(a.empty()){
if(b.empty()) return inv[2];
else return calc(b);
}
else if(b.empty()) return calc(a);
else return calc(a,b) + calc(b,a);
}
int main(){
freopen("expectation.in","r",stdin);
freopen("expectation.out","w",stdout);
X.pb(0),X.pb(1);
scanf("%d",&n);
rep(i,C[0][0]=1,maxn-1) rep(j,C[i][0]=1,i) C[i][j]=add(C[i-1][j-1],C[i-1][j]);
rep(i,2,maxn-1) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
rep(i,1,n-1){
scanf("%d%d",&xu[i],&xv[i]);
Node(xu[i],xv[i]),Node(xv[i],xu[i]);
}
int ans = 0;
rep(i,1,n-1) inc(ans,solve(i));
printf("%d\n",(ans+mod)%mod);
}