直接线段树合并维护选到每个数的概率,合并时先左再右,顺便维护某个点比它小的概率和,区间修改时打标记即可。
#include
#include
#include
#include
using namespace std;
#define LL long long
const int Maxn=300010;
const int mod=998244353;
int n,V[Maxn],P[Maxn],to[Maxn];
bool mark[Maxn];
struct Edge{int y,next;}e[Maxn];
int last[Maxn],len=0;
void ins(int x,int y)
{
int t=++len;
e[t].y=y;e[t].next=last[x];last[x]=t;
}
struct Node{int x,id;}A[Maxn];int la=0;
bool cmp(Node a,Node b){return a.xint root[Maxn],lc[Maxn*20],rc[Maxn*20],c[Maxn*20],tot=0,tag[Maxn*20];
int down(int x)
{
if(tag[x]!=1)
{
int t=tag[x];
tag[lc[x]]=(LL)tag[lc[x]]*t%mod;
c[lc[x]]=(LL)c[lc[x]]*t%mod;
tag[rc[x]]=(LL)tag[rc[x]]*t%mod;
c[rc[x]]=(LL)c[rc[x]]*t%mod;
tag[x]=1;
}
}
int up(int x){c[x]=(c[lc[x]]+c[rc[x]])%mod;}
void insert(int &u,int l,int r,int p)
{
if(!u)u=++tot;
tag[u]=1;
if(l==r){c[u]=1;return;}
int mid=l+r>>1;
if(p<=mid)insert(lc[u],l,mid,p);
else insert(rc[u],mid+1,r,p);
up(u);
}
int p1,p2;
void merge(int &u1,int u2,int px)
{
if(u1)down(u1);if(u2)down(u2);
if(!u1)
{
u1=u2;
int tmp=((LL)p1*px%mod+(LL)(1-p1+mod)%mod*(1-px+mod)%mod)%mod;
tag[u1]=(LL)tag[u1]*tmp%mod;
p2=(p2+c[u2])%mod;
c[u1]=(LL)c[u1]*tmp%mod;
return;
}
if(!u2)
{
int tmp=((LL)p2*px%mod+(LL)(1-p2+mod)%mod*(1-px+mod)%mod)%mod;
tag[u1]=(LL)tag[u1]*tmp%mod;
p1=(p1+c[u1])%mod;
c[u1]=(LL)c[u1]*tmp%mod;
return;
}
merge(lc[u1],lc[u2],px);
merge(rc[u1],rc[u2],px);
up(u1);
}
int ans=0;
void work(int x,int l,int r)
{
if(l==r)
{
ans=(ans+(LL)l*to[l]%mod*c[x]%mod*c[x]%mod)%mod;
return;
}
down(x);
int mid=l+r>>1;
work(lc[x],l,mid);work(rc[x],mid+1,r);
}
void dfs(int x)
{
int s1=-1,s2=-1;
for(int i=last[x];i;i=e[i].next)
{
int y=e[i].y;
dfs(y);
if(s1==-1)s1=y;
else s2=y;
}
if(s1==-1)insert(root[x],1,la,V[x]);
else if(s2==-1)root[x]=root[s1];
else
{
root[x]=root[s1];
p1=p2=0;
merge(root[x],root[s2],P[x]);
}
}
int Pow(int x,int y)
{
if(!y)return 1;
if(y==1)return x;
int t=Pow(x,y>>1),re=(LL)t*t%mod;
if(y&1)re=(LL)re*x%mod;
return re;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int fa;
scanf("%d",&fa);
if(fa)ins(fa,i),mark[fa]=true;
}
for(int i=1;i<=n;i++)
{
if(mark[i])scanf("%d",&P[i]),P[i]=(LL)Pow(10000,mod-2)*P[i]%mod;
else scanf("%d",&A[++la].x),A[la].id=i;
}
sort(A+1,A+1+la,cmp);
for(int i=1;i<=la;i++)V[A[i].id]=i,to[i]=A[i].x;
dfs(1);
work(root[1],1,la);
printf("%d",ans);
}