树链剖分+线段树,每个节点维护以下信息:
(1)单独在某个点分配$i$个人的最大收益。可以$O(m)$合并。
(2)分配$i$个人的最大收益。可以用$O(m^2)$合并。
时间复杂度$O(c(m^2\log n+m\log^2n))$。
#include<cstdio> #include<algorithm> using namespace std; typedef long long ll; const int N=20010,M=51,T=65600; int n,m,q,X=1<<16,Y=~0U>>1,A,B,Q,i,op,x,y; int g[N],nxt[N],f[N],d[N],size[N],son[N],st[N],en[N],top[N],seq[N],dfn; struct P{ ll v[M]; P(){for(int i=0;i<M;i++)v[i]=0;} P operator+(P b){ P c; for(int i=0;i<M;i++)c.v[i]=max(v[i],b.v[i]); return c; } P operator*(P b){ P c; for(int i=0;i<M;i++)for(int j=0;j<M-i;j++)c.v[i+j]=max(c.v[i+j],v[i]+b.v[j]); return c; } }tmp,a[N],v0[T],v1[T],s0,s1; inline int getint(){ A=((A^B)+B/X+B*X)&Y; B=((A^B)+A/X+A*X)&Y; return(A^B)%Q; } inline void gettmp(){ for(int i=1;i<=m;i++)tmp.v[i]=getint(); sort(tmp.v+1,tmp.v+m+1); } void dfs(int x){ size[x]=1; for(int i=g[x];i;i=nxt[i]){ d[i]=d[x]+1,dfs(i),size[x]+=size[i]; if(size[i]>size[son[x]])son[x]=i; } } void dfs2(int x,int y){ seq[st[x]=++dfn]=x;top[x]=y; if(son[x])dfs2(son[x],y); for(int i=g[x];i;i=nxt[i])if(i!=son[x])dfs2(i,i); en[x]=dfn; } inline void up(int x){ v0[x]=v0[x<<1]+v0[x<<1|1]; v1[x]=v1[x<<1]*v1[x<<1|1]; } void build(int x,int a,int b){ if(a==b){v0[x]=v1[x]=::a[seq[a]];return;} int mid=(a+b)>>1; build(x<<1,a,mid),build(x<<1|1,mid+1,b),up(x); } void change(int x,int a,int b,int c){ if(a==b){v0[x]=v1[x]=tmp;return;} int mid=(a+b)>>1; if(c<=mid)change(x<<1,a,mid,c);else change(x<<1|1,mid+1,b,c); up(x); } void ask0(int x,int a,int b,int c,int d){ if(c<=a&&b<=d){s0=s0+v0[x];return;} int mid=(a+b)>>1; if(c<=mid)ask0(x<<1,a,mid,c,d); if(d>mid)ask0(x<<1|1,mid+1,b,c,d); } void ask1(int x,int a,int b,int c,int d){ if(c<=a&&b<=d){s1=s1*v1[x];return;} int mid=(a+b)>>1; if(c<=mid)ask1(x<<1,a,mid,c,d); if(d>mid)ask1(x<<1|1,mid+1,b,c,d); } inline void chain(int x,int y){ if(x==y)return; x=f[x]; while(top[x]!=top[y])ask0(1,1,n,st[top[x]],st[x]),x=f[top[x]]; ask0(1,1,n,st[y],st[x]); } int main(){ scanf("%d%d%d%d%d",&n,&m,&A,&B,&Q); for(i=2;i<=n;i++)scanf("%d",&f[i]),nxt[i]=g[f[i]],g[f[i]]=i; for(i=1;i<=n;i++)gettmp(),a[i]=tmp; dfs(1),dfs2(1,1),build(1,1,n); scanf("%d",&q); while(q--){ scanf("%d%d",&op,&x); if(!op)gettmp(),change(1,1,n,st[x]); else{ scanf("%d",&y); s0=s1=P(); chain(x,y),ask1(1,1,n,st[x],en[x]); s0=s0*s1; printf("%lld\n",s0.v[m]); } } return 0; }