这个题的 n^2 dp 是很显然的 线段树优化dp 也是很显然的
这个题的价值在于增加线段树合并技能熟练度
#include #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace std; typedef double db; typedef long long ll; typedef unsigned int uint; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch<='9'&&ch>='0'){x=10*x+ch-'0';ch=getchar();} return x*f; } void print(int x) {if(x<0)putchar('-'),x=-x;if(x>=10)print(x/10);putchar(x%10+'0');} const int N=300100,mod=998244353; inline int qpow(int x,int y) { int res(1); while(y) { if(y&1) res=(ll)x*res%mod; x=(ll)x*x%mod; y>>=1; } return res; } int ch[N][2]; int tot; int V[N],P[N]; struct president_tree{int ls,rs,sum,tag;}tr[N*30]; int root[N],sz; inline void pushdown(int k) { if(tr[k].tag) { int tag=tr[k].tag,ls=tr[k].ls,rs=tr[k].rs; tr[ls].sum=(ll)tr[ls].sum*tag%mod, tr[rs].sum=(ll)tr[rs].sum*tag%mod, tr[ls].tag=(ll)tr[ls].tag*tag%mod, tr[rs].tag=(ll)tr[rs].tag*tag%mod, tr[k].tag=1; } } void insert(int &k,int l,int r,int x,int val) { k=++sz; tr[k].tag=1; tr[k].sum+=val; if(l==r) return ; int mid=(l+r)>>1; x<=mid ? insert(tr[k].ls,l,mid,x,val) : insert(tr[k].rs,mid+1,r,x,val); } int merger(int x,int y,int sum_x,int sum_y,int u) { if(!x) { tr[y].sum=(ll)tr[y].sum*sum_x%mod, tr[y].tag=(ll)tr[y].tag*sum_x%mod; return y; } if(!y) { tr[x].sum=(ll)tr[x].sum*sum_y%mod, tr[x].tag=(ll)tr[x].tag*sum_y%mod; return x; } int val_x[2],val_y[2]; pushdown(x),pushdown(y); val_x[0]=tr[tr[x].ls].sum, val_x[1]=tr[tr[x].rs].sum, val_y[0]=tr[tr[y].ls].sum, val_y[1]=tr[tr[y].rs].sum; tr[x].ls=merger(tr[x].ls,tr[y].ls,(sum_x+(ll)val_x[1]*(1+mod-P[u]))%mod,(sum_y+(ll)val_y[1]*(1+mod-P[u]))%mod,u); tr[x].rs=merger(tr[x].rs,tr[y].rs,(sum_x+(ll)val_x[0]*P[u])%mod,(sum_y+(ll)val_y[0]*P[u])%mod,u); tr[x].sum=(tr[tr[x].ls].sum+tr[tr[x].rs].sum)%mod; return x; } void dfs(int u) { if(!u) return ; if(!ch[u][0]) { insert(root[u],1,tot,lower_bound(V+1,V+1+tot,P[u])-V,1); return ; } dfs(ch[u][0]),dfs(ch[u][1]); if(!ch[u][1]) root[u]=root[ch[u][0]]; else root[u]=merger(root[ch[u][0]],root[ch[u][1]],0,0,u); } int ans(0); void cal(int k,int l,int r) { if(l==r) { (ans+=(ll)l*V[l]%mod*tr[k].sum%mod*tr[k].sum%mod)%=mod; return ; } pushdown(k); int mid=(l+r)>>1; cal(tr[k].ls,l,mid),cal(tr[k].rs,mid+1,r); } int main() { int n=read(); register int i,x; for(i=1;i<=n;++i) x=read(),ch[x][ch[x][0] ? 1 : 0]=i; int inv_w=qpow(10000,mod-2); for(i=1;i<=n;++i) P[i]=read(), ch[i][0] ? P[i]=(ll)P[i]*inv_w%mod : V[++tot]=P[i]; sort(V+1,V+1+tot); dfs(1); cal(root[1],1,tot); cout<