首先发现每个叶子节点的权值都有可能成为最终根节点的权值,观察题目中给出的式子:
\[ \sum_{i=1}^m i V_i D_i^2 \]
发现只要算出每个权值被根节点取到的概率 \(D_i\),然后就能计算答案。
设 \(f_{x,i}\) 为节点 \(x\) 取到第 \(i\) 小权值的概率,根据是从左儿子还是从右儿子取到的权值来进行分类讨论:
\[f_{x,i} = f_{ls,i}(p_x \sum_{j=1}^{i-1}f_{rs,j}+(1-p_x) \sum_{j=i+1}^m f_{rs,j}) + f_{rs,i}(p_x \sum_{j=1}^{i-1}f_{ls,j}+(1-p_x) \sum_{j=i+1}^m f_{ls,j}) \]
直接进行转移的话,时间和空间都无法接受,由本题为树形 \(DP\) 和转移方程的特点得,可以通过整体 \(DP\) 来优化。
用动态开点线段树来维护第二维,通过线段树合并实现从左右儿子转移。当一个节点没有儿子时就直接赋初值,只有一个儿子就直接转移。线段树合并时记录前缀和和后缀和,来计算左右区间互相的贡献,线段树上区间乘,打标记即可。
\(code:\)
#include
#define maxn 600010
#define maxm 10000010
#define p 998244353
#define inv 796898467
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
template inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,tot,cnt;
ll ans;
int ch[maxn][2],rt[maxn],ls[maxm],rs[maxm];
ll val[maxn],s[maxn],d[maxn],sum[maxm],tag[maxm];
void pushtag(int x,ll v)
{
sum[x]=sum[x]*v%p,tag[x]=tag[x]*v%p;
}
void pushdown(int x)
{
if(tag[x]==1) return;
pushtag(ls[x],tag[x]),pushtag(rs[x],tag[x]),tag[x]=1;
}
void insert(int l,int r,int pos,int &cur)
{
if(!cur) cur=++tot,tag[cur]=1;
sum[cur]++;
if(l==r) return;
if(pos<=mid) insert(l,mid,pos,ls[cur]);
else insert(mid+1,r,pos,rs[cur]);
}
int merge(int x,int y,ll lx,ll rx,ll ly,ll ry,ll v)
{
if(!x&&!y) return 0;
pushdown(x),pushdown(y);
if(x&&!y)
{
pushtag(x,(v*ly%p+(1-v+p)*ry%p)%p);
return x;
}
if(!x&&y)
{
pushtag(y,(v*lx%p+(1-v+p)*rx%p)%p);
return y;
}
ll px=sum[ls[x]],sx=sum[rs[x]],py=sum[ls[y]],sy=sum[rs[y]];
ls[x]=merge(ls[x],ls[y],lx,(rx+sx)%p,ly,(ry+sy)%p,v);
rs[x]=merge(rs[x],rs[y],(lx+px)%p,rx,(ly+py)%p,ry,v);
sum[x]=(sum[ls[x]]+sum[rs[x]])%p;
return x;
}
void solve(int x)
{
if(!ch[x][0])
{
insert(1,cnt,val[x],rt[x]);
return;
}
if(!ch[x][1])
{
solve(ch[x][0]),rt[x]=rt[ch[x][0]];
return;
}
solve(ch[x][0]),solve(ch[x][1]);
rt[x]=merge(rt[ch[x][0]],rt[ch[x][1]],0,0,0,0,val[x]);
}
void dfs(int l,int r,int cur)
{
if(l==r)
{
d[l]=sum[cur];
return;
}
pushdown(cur),dfs(l,mid,ls[cur]),dfs(mid+1,r,rs[cur]);
}
int main()
{
read(n);
for(int i=1;i<=n;++i)
{
int fa;
read(fa);
if(ch[fa][0]) ch[fa][1]=i;
else ch[fa][0]=i;
}
for(int i=1;i<=n;++i)
{
read(val[i]);
if(ch[i][0]) val[i]=val[i]*inv%p;
else s[++cnt]=val[i];
}
sort(s+1,s+cnt+1);
for(int i=1;i<=n;++i)
if(!ch[i][0])
val[i]=lower_bound(s+1,s+cnt+1,val[i])-s;
solve(1),dfs(1,cnt,rt[1]);
for(int i=1;i<=cnt;++i)
ans=(ans+i*s[i]%p*d[i]%p*d[i]%p)%p;
printf("%lld",ans);
return 0;
}