f [ x ] [ 0 / 1 ] f[x][0/1] f[x][0/1]表示不选/选 x x x, x x x的子树中的最大带权独立集, y y y是 x x x的儿子。
f [ x ] [ 1 ] = v [ x ] + ∑ f [ y ] [ 0 ] f [ x ] [ 0 ] = ∑ m a x ( f [ y ] [ 0 ] , f [ y ] [ 1 ] ) f[x][1]=v[x]+\sum f[y][0] \\f[x][0]=\sum max(f[y][0],f[y][1]) f[x][1]=v[x]+∑f[y][0]f[x][0]=∑max(f[y][0],f[y][1])
树剖后, y是x的轻儿子,设
g 1 [ x ] = v [ x ] + ∑ f [ y ] [ 0 ] g 2 [ x ] = ∑ m a x ( f [ y ] [ 0 ] , f [ y ] [ 1 ] ) g_1[x]=v[x]+\sum f[y][0] \\g_2[x]=\sum max(f[y][0],f[y][1]) g1[x]=v[x]+∑f[y][0]g2[x]=∑max(f[y][0],f[y][1])
f [ x ] [ 1 ] = g 1 [ x ] + f [ m s o n ] [ 0 ] f [ x ] [ 0 ] = g 2 [ x ] + m a x ( f [ m s o n ] [ 0 ] , f [ m s o n ] [ 1 ] ) f[x][1]=g_1[x]+f[mson][0] \\f[x][0]=g_2[x]+max(f[mson][0],f[mson][1]) f[x][1]=g1[x]+f[mson][0]f[x][0]=g2[x]+max(f[mson][0],f[mson][1])
写成矩阵(把*换成+,+换成取min/max,仍满足矩阵的性质)
[ f [ x ] [ 0 ] f [ x ] [ 1 ] ] = [ g 2 [ x ] g 2 [ x ] g 1 [ x ] − i n f ] [ f [ m s o n ] [ 0 ] f [ m s o n ] [ 1 ] ] \begin{bmatrix} f[x][0] \\ f[x][1] \end{bmatrix}= \begin{bmatrix} g_2[x] & g_2[x]\\ g_1[x] & -inf \end{bmatrix} \begin{bmatrix} f[mson][0] \\ f[mson][1] \end{bmatrix} [f[x][0]f[x][1]]=[g2[x]g1[x]g2[x]−inf][f[mson][0]f[mson][1]]
线段树维护重链上的矩阵就好了。
然鹅这样是两个log,十分不优秀。
发现对整棵树建线段树非常没有必要,我们只需要维护每条重链上的信息。如果用splay维护每条重链,lct维护整棵树可以做到log的复杂度,然鹅lct自带巨无霸常数。
而这里我们并不需要动态的数据结构,lct大材小用了,可以考虑给每根重链建一棵bst。
每次找到重链的带权重心(权重为轻儿子的sz+1)作为根,然后递归建左右子树。这样发现经过无论bst上的边还是轻边,sz都至少减小一半,于是最多只会经过log条边。
用这棵可爱的bst维护矩阵就好了。
luogu P4751 动态dp【加强版】
//Achen
#include
#define For(i,a,b) for(register int i=(a);i<=(b);i++)
#define Rep(i,a,b) for(register int i=(a);i>=(b);i--)
const int N=1e6+7;
using namespace std;
typedef long long LL;
typedef double db;
const int inf=1e9;
int n,m,val[N];
template<typename T>void read(T &x) {
char ch=getchar(); x=0; T f=1;
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=-1,ch=getchar();
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0'; x*=f;
}
struct jz {
int a[2][2];
friend jz operator *(const jz&A,const jz&B) {
jz rs;
For(i,0,1) For(j,0,1) {
if(A.a[i][0]+B.a[0][j]>A.a[i][1]+B.a[1][j])
rs.a[i][j]=A.a[i][0]+B.a[0][j];
else rs.a[i][j]=A.a[i][1]+B.a[1][j];
}
return rs;
}
int f0() { return a[0][0]; }
int f1() { return a[1][0]; }
}dt[N],sum[N];
int ecnt,fir[N],nxt[N<<1],to[N<<1];
void add(int u,int v) {
nxt[++ecnt]=fir[u]; fir[u]=ecnt; to[ecnt]=v;
nxt[++ecnt]=fir[v]; fir[v]=ecnt; to[ecnt]=u;
}
int p[N],sz[N],nsz[N],mson[N];
void dfs1(int x,int fa) {
sz[x]=1;
for(int i=fir[x];i;i=nxt[i]) if(to[i]!=fa) {
dfs1(to[i],x);
sz[x]+=sz[to[i]];
if(!mson[x]||sz[to[i]]>sz[mson[x]]) mson[x]=to[i];
}
nsz[x]=sz[x]-sz[mson[x]];
}
int ch[N][2];
#define lc ch[x][0]
#define rc ch[x][1]
int isroot(int x) { return ch[p[x]][0]!=x&&ch[p[x]][1]!=x; }
inline void upd(int x) {
if(lc) sum[x]=sum[lc]*dt[x]; else sum[x]=dt[x];
if(rc) sum[x]=sum[x]*sum[rc];
}
int sta[N],top;
int build(int l,int r) {
int tot=0,ntot=0;
For(i,l,r) tot+=nsz[sta[i]];
For(i,l,r) {
ntot+=nsz[sta[i]];
if(ntot*2>=tot) {
int x=sta[i];
lc=build(l,i-1); if(lc) p[lc]=x;
rc=build(i+1,r); if(rc) p[rc]=x;
upd(x); return x;
}
} return 0;
}
int RT;
int dfs2(int x) {
for(int y=x;y;y=mson[y]) {
dt[y]=(jz){0,0,val[y],0};
for(int i=fir[y];i;i=nxt[i]) if(sz[to[i]]<sz[y]&&to[i]!=mson[y]) {
int z=dfs2(to[i]); p[z]=y;
int t=max(sum[z].f0(),sum[z].f1());
dt[y].a[0][0]+=t; dt[y].a[0][1]+=t;
dt[y].a[1][0]+=sum[z].f0();
}
}
top=0;
for(int i=x;i;i=mson[i]) sta[++top]=i;
int rs=build(1,top);
return rs;
}
int lastans;
inline void change(int x,int vl) {
dt[x].a[1][0]+=(vl-val[x]); val[x]=vl;
while(x!=RT) {
if(isroot(x)) {
int t=max(sum[x].f0(),sum[x].f1());
dt[p[x]].a[0][0]-=t; dt[p[x]].a[0][1]-=t;
dt[p[x]].a[1][0]-=sum[x].f0();
}
upd(x);
if(isroot(x)) {
int t=max(sum[x].f0(),sum[x].f1());
dt[p[x]].a[0][0]+=t; dt[p[x]].a[0][1]+=t;
dt[p[x]].a[1][0]+=sum[x].f0();
}
x=p[x];
} upd(x);
lastans=max(sum[x].f0(),sum[x].f1());
printf("%d\n",lastans);
}
int main() {
//freopen("1.in","r",stdin);
//freopen("1.out","w",stdout);
read(n); read(m);
For(i,1,n) read(val[i]);
For(i,2,n) {
int u,v;
read(u); read(v);
add(u,v);
}
dfs1(1,0);
RT=dfs2(1);
For(i,1,m) {
int x,y;
read(x); read(y);
x^=lastans;
change(x,y);
}
return 0;
}
どこでもドア
写了这两道动态dp的题又用cai大佬教的树剖做了Qtree系列让我更加不理解动态dp到底是啥子矩阵又有什么用了,这道题让我终于不再一头雾水了。
设为 f ( e ) f(e) f(e)为当前子树中包含根节点的每种权值联通块数目的生成函数, g ( e ) g(e) g(e)为子树中所有的每种权值联通块数目的生成函数(g就是答案的生成函数啦)。y是x的儿子。
f x ( e ) = e v a l x ∏ f y ( e ) + e 0 g x ( e ) = f x ( e ) − e 0 + ∑ g y ( e ) f_x(e)=e^{val_x}\prod f_y(e)+e^0 \\g_x(e)=f_x(e)-e^0+\sum g_y(e) fx(e)=evalx∏fy(e)+e0gx(e)=fx(e)−e0+∑gy(e)
这个dp是可以用FWT优化的,FWT后可以直接乘除和加减,且可以最后再在根上IFWT回去得到需要的 g r o o t ( e ) g_{root}(e) groot(e),就非常方便了。
带修改我们仍然树剖,下面y为x的轻儿子。
f x ( e ) = ( e v a l x ∏ f y ( e ) ) ∗ f m s o n ( e ) + e 0 g x ( e ) = ( e v a l x ∏ f y ( e ) ) ∗ f m s o n ( e ) + g m s o n ( e ) + ∑ g y ( e ) f_x(e)=(e^{val_x}\prod f_y(e))*f_{mson}(e)+e^0 \\g_x(e)=(e^{val_x}\prod f_y(e))*f_{mson}(e)+g_{mson}(e)+\sum g_y(e) fx(e)=(evalx∏fy(e))∗fmson(e)+e0gx(e)=(evalx∏fy(e))∗fmson(e)+gmson(e)+∑gy(e)
写成矩阵
[ f x g x 1 ] = [ e v a l x ∏ f y ( e ) 0 e 0 e v a l x ∏ f y ( e ) 1 ∑ g y 0 0 1 ] [ f m s o n g m s o n 1 ] \begin{bmatrix} f_x \\ g_x\\ 1 \end{bmatrix}= \begin{bmatrix} e^{val_x}\prod f_y(e) & 0 &e^0\\ e^{val_x}\prod f_y(e) & 1 & \sum g_y \\ 0 & 0 & 1 \end{bmatrix} \begin{bmatrix} f_{mson} \\ g_{mson}\\ 1 \end{bmatrix} ⎣⎡fxgx1⎦⎤=⎣⎡evalx∏fy(e)evalx∏fy(e)0010e0∑gy1⎦⎤⎣⎡fmsongmson1⎦⎤
这样直接套上面那个模板,矩阵里面套数组,就可以 O ( n m l o g ∗ 3 3 ) O(nmlog*3^3) O(nmlog∗33)了,应该是可以过的,如果写树剖上线段树再带一个log就不知道能不能过了。
然鹅,看大佬们的题解都是线段树上维护子段和一样的东西,这两道取min,max我还比较容易感性理解,但这道题是个啥啊??
不要方张,我们发现这个矩阵非常玄妙啊
[ a 1 0 b 1 c 1 1 d 1 0 0 1 ] [ a 2 0 b 2 c 2 1 d 2 0 0 1 ] = [ a 1 a 2 0 a 1 b 2 + b 1 c 1 a 2 + c 2 1 c 1 b 2 + d 2 + d 1 0 0 1 ] \begin{bmatrix} a_1 & 0 &b_1\\ c_1 & 1 &d_1\\ 0 & 0 & 1 \end{bmatrix} \begin{bmatrix} a_2 & 0 &b_2\\ c_2 & 1 &d_2\\ 0 & 0 & 1 \end{bmatrix}= \begin{bmatrix} a_1a_2 & 0 &a_1b_2+b_1\\ c_1a_2+c_2 & 1 &c_1b_2+d_2+d_1\\ 0 & 0 & 1 \end{bmatrix} ⎣⎡a1c10010b1d11⎦⎤⎣⎡a2c20010b2d21⎦⎤=⎣⎡a1a2c1a2+c20010a1b2+b1c1b2+d2+d11⎦⎤
(在矩阵后面写个向量可以发现f和g就是a+b和c+d)
只需要对每个矩阵维护a,b,c,d就可以了。且这就是一个子段和的形式。有没有什么直接得到子段和的方式啊宸就不太清楚了,啊宸是直接把它当成矩阵看待然后套上个题模板的的。
传送门里的那两道题啊宸当初学了矩阵后试图用矩阵去做然后复杂度太高,现在再看发现写成矩阵后(懒得打上来了)同样可以用这种方式化成4or3个值再用子段和形式维护,困扰啊宸很久的问题得到解决惹。但是似乎重链上挂了set就算用了bst维护重链两个log的复杂度还是少不了的样子。
还有刚才那道模板的矩阵是不能化的。
哦还有这题因为取模又是除法是写了个可以支持取模乘除多个0的类。
//Achen
#include
#define For(i,a,b) for(register int i=(a);i<=(b);i++)
#define Rep(i,a,b) for(register int i=(a);i>=(b);i--)
const int N=30007,mod=10007,inv2=5004;
using namespace std;
typedef long long LL;
typedef double db;
const int inf=1e9;
int n,m,val[N],UP,K,inv[mod+7],hvson[N];
char op[10];
template<typename T>void read(T &x) {
char ch=getchar(); x=0; T f=1;
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=-1,ch=getchar();
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0'; x*=f;
}
void FWT(int a[],int f) {
for(int i=1;i<UP;i<<=1)
for(int j=0,pp=(i<<1);j<UP;j+=pp)
for(int k=0;k<i;k++) {
int x=a[j+k],y=a[j+k+i];
a[j+k]=(x+y)%mod; a[j+k+i]=(x-y+mod)%mod;
if(f==-1) (a[j+k]*=inv2)%=mod,(a[j+k+i]*=inv2)%=mod;
}
}
struct num {
int v,c;
}fy[N][128];
num operator *(const num&A,const num&B) { return (num){A.v*B.v%mod,A.c+B.c}; }
num operator /(const num&A,const num&B) { return (num){A.v*inv[B.v]%mod,A.c-B.c}; }
void get(num a[],int b[]) { For(i,0,UP-1) b[i]=a[i].c?0:a[i].v; }
void get(int a[],num b[]) { For(i,0,UP-1) b[i]=(a[i]?(num){a[i],0}:(num){1,1}); }
struct jz {
int a[128],b[128],c[128],d[128];
friend jz operator *(const jz&A,const jz&B) {
jz rs;
For(i,0,UP-1) {
rs.a[i]=A.a[i]*B.a[i]%mod;
rs.b[i]=(A.a[i]*B.b[i]%mod+A.b[i])%mod;
rs.c[i]=(A.c[i]*B.a[i]%mod+B.c[i])%mod;
rs.d[i]=(A.c[i]*B.b[i]%mod+A.d[i]+B.d[i])%mod;
}
return rs;
}
}dt[N],sum[N];
int prval[128][128];
void pre() {
inv[0]=inv[1]=1;
For(i,2,mod-1) inv[i]=(mod-mod/i*inv[mod%i]%mod)%mod;
For(i,0,UP-1) {
For(j,0,UP-1) prval[i][j]=0;
prval[i][i]=1;
FWT(prval[i],1);
}
}
void get_f(int a[],int val) {
For(i,0,UP-1) a[i]=prval[val][i];
}
int ecnt,fir[N],nxt[N<<1],to[N<<1];
void add(int u,int v) {
nxt[++ecnt]=fir[u]; fir[u]=ecnt; to[ecnt]=v;
nxt[++ecnt]=fir[v]; fir[v]=ecnt; to[ecnt]=u;
}
int p[N],sz[N],nsz[N],mson[N];
void dfs1(int x,int fa) {
sz[x]=1;
for(int i=fir[x];i;i=nxt[i]) if(to[i]!=fa) {
dfs1(to[i],x);
hvson[x]++;
sz[x]+=sz[to[i]];
if(!mson[x]||sz[to[i]]>sz[mson[x]]) mson[x]=to[i];
}
hvson[x]=hvson[x]>1?1:0;
nsz[x]=sz[x]-sz[mson[x]];
}
int ch[N][2];
#define lc ch[x][0]
#define rc ch[x][1]
int isroot(int x) { return ch[p[x]][0]!=x&&ch[p[x]][1]!=x; }
inline void upd(int x) {
if(lc) sum[x]=sum[lc]*dt[x]; else sum[x]=dt[x];
if(rc) sum[x]=sum[x]*sum[rc];
}
int sta[N],top;
int build(int l,int r) {
int tot=0,ntot=0;
For(i,l,r) tot+=nsz[sta[i]];
For(i,l,r) {
ntot+=nsz[sta[i]];
if(ntot*2>=tot) {
int x=sta[i];
lc=build(l,i-1); if(lc) p[lc]=x;
rc=build(i+1,r); if(rc) p[rc]=x;
upd(x); return x;
}
} return 0;
}
int RT;
int tpf[N];
num tpff[N];
void getac(int x) {
get_f(dt[x].a,val[x]);
if(hvson[x]) {
get(fy[x],tpf);
For(l,0,UP-1) (dt[x].a[l]*=tpf[l])%=mod;
}
For(i,0,UP-1) dt[x].c[i]=dt[x].a[i];
}
int dfs2(int x) {
for(int y=x;y;y=mson[y]) {
get_f(dt[y].b,0);
For(l,0,UP-1) dt[y].d[l]=0;
int fl=0;
for(int i=fir[y];i;i=nxt[i]) if(sz[to[i]]<sz[y]&&to[i]!=mson[y]) {
int z=dfs2(to[i]); p[z]=y;
For(l,0,UP-1) tpf[l]=(sum[z].a[l]+sum[z].b[l])%mod;
if(!fl) { get(tpf,fy[y]); fl=1; }
else { get(tpf,tpff); For(l,0,UP-1) fy[y][l]=fy[y][l]*tpff[l]; }
For(l,0,UP-1) (dt[y].d[l]+=sum[z].c[l]+sum[z].d[l])%=mod;
}
getac(y);
}
top=0;
for(int i=x;i;i=mson[i]) sta[++top]=i;
int rs=build(1,top);
return rs;
}
int lastans;
inline void change(int x,int vl) {
val[x]=vl;
getac(x);
while(x!=RT) {
if(isroot(x)&&p[x]) {
For(l,0,UP-1) tpf[l]=(sum[x].a[l]+sum[x].b[l])%mod;
get(tpf,tpff);
For(l,0,UP-1) fy[p[x]][l]=fy[p[x]][l]/tpff[l];
For(l,0,UP-1) dt[p[x]].d[l]=(dt[p[x]].d[l]-sum[x].c[l]-sum[x].d[l]+mod+mod)%mod;
}
upd(x);
if(isroot(x)&&p[x]) {
For(l,0,UP-1) tpf[l]=(sum[x].a[l]+sum[x].b[l])%mod;
get(tpf,tpff);
For(l,0,UP-1) fy[p[x]][l]=fy[p[x]][l]*tpff[l];
For(l,0,UP-1) dt[p[x]].d[l]=(dt[p[x]].d[l]+sum[x].c[l]+sum[x].d[l])%mod;
getac(p[x]);
}
x=p[x];
} upd(x);
}
int main() {
//freopen("cut.in","r",stdin);
//freopen("cut.out","w",stdout);
read(n); read(UP);
pre();
For(i,1,n) read(val[i]);
For(i,2,n) {
int u,v;
read(u); read(v);
add(u,v);
}
dfs1(1,0);
RT=dfs2(1);
int Q; read(Q);
For(cs,1,Q) {
int x,y;
scanf("%s",op);
if(op[0]=='C') {
read(x); read(y);
change(x,y);
}
else {
read(K);
For(i,0,UP-1) tpf[i]=(sum[RT].c[i]+sum[RT].d[i])%mod;
FWT(tpf,-1);
printf("%d\n",tpf[K]);
}
}
return 0;
}