【LOJ】#2537. 「PKUWC2018」Minimax-线段树合并

题解

我们从叶子节点逐层递归到根节点。
我们可以建动态开点权值线段树,每个结点上建一颗,递归时考虑合并左右子树的信息。
可以得到,当合并左右儿子到某父结点上时,可以这样转移:
设f[i][j]是在i结点上,权值为j(令权值为j只能从左儿子中转移上来)的概率,则 f[i][j]=f[lc[i]][j]((1p)sum1[rc[i]](i)+psum2[rc[i]](i)) f [ i ] [ j ] = f [ l c [ i ] ] [ j ] ∗ ( ( 1 − p ) ∗ s u m 1 [ r c [ i ] ] ( 所 有 值 大 于 i 的 节 点 上 的 概 率 之 和 ) + p ∗ s u m 2 [ r c [ i ] ] ( 所 有 值 小 于 i 的 节 点 上 的 概 率 之 和 ) )
具体做法见代码,感觉这个实现很巧妙啊(线段树博大精深(跪),我们记一个区间乘标记就万事大吉了啊。


代码

#include
#include
#include
#define ls ch[k][0]
#define rs ch[k][1]
#define mid (((l)+(r))>>1)
using namespace std;
typedef long long ll;
const int N=3e5+10,M=5e6+10,mod=998244353;
int n,son[N][2],rt[N],ch[M][2],p,w[N],rk[N],tot,cnt;
int inv=796898467;ll s[M],cg[M];
struct P{
  int val,id;
  bool operator <(const P&u)const{
     return valint rd()
{
    char ch=getchar();int x=0,f=1;
    while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+(ch^48);ch=getchar();}
    return x*f;
}
inline void pushdown(int k)
{
    if(cg[k]==1) return;
    s[ls]=s[ls]*cg[k]%mod;s[rs]=s[rs]*cg[k]%mod;
    cg[ls]=cg[ls]*cg[k]%mod;cg[rs]=cg[rs]*cg[k]%mod;
    cg[k]=1;
}
inline void insert(int &k,int l,int r,int pos)
{
    if(!k) k=++tot;s[k]=cg[k]=1;
    if(l==r) return;
    if(pos<=mid) insert(ls,l,mid,pos);
    else insert(rs,mid+1,r,pos);
}

inline int merge(int x,int y,ll sx,ll sy)
{
    if(!x){s[y]=s[y]*sy%mod;cg[y]=cg[y]*sy%mod;return y;}
    if(!y){s[x]=s[x]*sx%mod;cg[x]=cg[x]*sx%mod;return x;}
    pushdown(x);pushdown(y);
    ll A=s[ch[x][0]],B=s[ch[x][1]],C=s[ch[y][0]],D=s[ch[y][1]];
    ch[x][0]=merge(ch[x][0],ch[y][0],(sx+(1-p+mod)*D%mod)%mod,(sy+(1-p+mod)*B%mod)%mod);
    ch[x][1]=merge(ch[x][1],ch[y][1],(sx+p*C%mod)%mod,(sy+p*A%mod)%mod);
    s[x]=(s[ch[x][0]]+s[ch[x][1]])%mod;
    return x;
}

inline int solve(int x)
{
    if(!son[x][0]){insert(rt[x],1,cnt,rk[x]);return rt[x];}
    int l=solve(son[x][0]);
    if(!son[x][1]) return l;
    int r=solve(son[x][1]);
    p=w[x];
    return merge(l,r,0,0);
} 

inline ll cal(int k,int l,int r)
{
    if(l==r){return 1ll*l*t[l].val%mod*s[k]%mod*s[k]%mod;}
    pushdown(k);
    return (cal(ls,l,mid)+cal(rs,mid+1,r))%mod;
}

int main(){
    int i,j,k;
    n=rd();
    for(i=1;i<=n;++i){j=rd();son[j][(son[j][0]!=0)]=i;}
    for(i=1;i<=n;++i){
      w[i]=rd();
      if(son[i][0]) w[i]=1ll*w[i]*inv%mod;//w[i]->1ll*w[i]
      else{t[++cnt].val=w[i];t[cnt].id=i;}
    } 
    sort(t+1,t+cnt+1);
    for(i=1;i<=cnt;++i) rk[t[i].id]=i;
    k=solve(1);
    printf("%lld\n",cal(k,1,cnt));
}

你可能感兴趣的:(妙,线段树可持久化合并)