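题目大意：给定一棵 n 个点的树，每个点 i 有权值 w[i] 和颜色 c[i]。共 q 个操作：`CC x y` 把点 x 的颜色改为 y；`CW x y` 把点 x 的权值改为 y；`QS x y` / `QM x y` 询问 x 到 y 的路径上与 x 同色的点的权值之和 / 最大值（代码默认 x 与 y 同色）。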
解题思路
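离线读入所有操作，收集所有可能出现的（点，颜色）对：每个点的初始颜色，以及每个 CC 操作的目标颜色。把这些对按颜色分组，组内按 dfs 序排序，用栈对每种颜色建出虚树（补上必要的 LCA 点）。然后用 LCT 维护这些虚树：每个（点，颜色）对对应一个 LCT 结点。点当前颜色对应的结点权值为 w，其余结点权值为 0。这样每个询问就变成 LCT 上的一次路径求和 / 路径最大值。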
code
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
// Shorthand types used by this solution.
#define LF double
#define LL long long
// Previously defined as `unsigned int`, which contradicts the name; nothing
// visible in this solution uses ULL, so the correction is behavior-neutral.
#define ULL unsigned long long
// Inclusive ascending / descending integer loops: i runs from j to k.
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
// Iterate over the edge ids of vertex j's adjacency list (heads in begin[],
// links in next[]; an edge id of 0 terminates the list).
// NOTE(review): with `using namespace std;` below, global arrays named
// begin/next/end can clash with std::begin/std::next/std::end under C++11 and
// later ("reference is ambiguous"); renaming them would need a file-wide change.
#define fr(i,j) for(int i=begin[j];i;i=next[i])
using namespace std;
// Larger of two ints (a plain non-template overload, preferred over std::max
// for int arguments).
int max(int x,int y){
    if(x>y)return x;
    return y;
}
// Smaller of two ints. This line was fused with the constants below in the
// damaged copy; it is restored as the mirror of max().
int min(int x,int y){return (x<y)?x:y;}
// Array bounds: mn sizes per-vertex / per-query arrays, mm sizes arrays that
// hold edges, (vertex, color) pairs and LCT nodes. mq and inf are declared but
// not referenced in the visible code.
int const mn=1e5+9,mm=4*1e5+9,mq=2*1e4+9,inf=1e9+7;
// Global state, declared in the same order as before.
int n,q,                     // vertex count, operation count
    lg2,                     // highest binary-lifting level used in f[][]
    pon,                     // last LCT node id handed out (ids start at 1)
    time,                    // DFS clock for dfn[]
    gra,                     // last edge id in the adjacency list
    w[mn],c[mn],             // current weight / color of each vertex
    op[mn],x[mn],y[mn],      // offline copy of the operations
    f[mn][20],               // f[v][k]: 2^k-th ancestor of v (0 above the root)
    beg[mn],end[mn],         // a[beg[v]..end[v]]: all pairs of vertex v
    begin[mn],to[mm],next[mm], // adjacency list: head, target, link
    dfn[mn],dep[mn],         // DFS entry time and depth
    val[mm],sum[mm],mx[mm],  // LCT node value and splay-subtree sum / max
    son[mm][2],fa[mm],       // splay children and splay parent
    par[mm],                 // path-parent pointer (kept on a path's topmost node)
    tag[mm],                 // pending reversal of a node's children's subtrees
    st[mm];                  // virtual-tree build stack; st[0] is its size
// One (vertex, color) pair, or an LCA vertex added for a color's virtual tree.
struct rec{
    int p;   // vertex in the original tree
    int c;   // color whose virtual tree this entry belongs to
    int to;  // LCT node id assigned to this entry
};
// All entries, used from index 1 (a[0] stays zero as a sentinel).
rec a[mm];
// Prepends the directed edge u -> v to u's adjacency list.
void insert(int u,int v){
    gra++;
    to[gra]=v;
    next[gra]=begin[u];
    begin[u]=gra;
}
void dfs(int now,int pre){
dfn[now]=++time;
dep[now]=dep[pre]+1;
fr(i,now)if(to[i]!=pre){
f[to[i]][0]=now;
dfs(to[i],now);
}
}
bool cmp(rec i,rec j){
return (i.cbool cm2(rec i,rec j){
return (i.pint lc(int u,int v){
if(dep[u]0)if(dep[f[u][i]]>=dep[v])u=f[u][i];
if(u==v)return u;
fd(i,lg2,0)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
return f[u][0];
}
int find(int x,int y){
int l=beg[x],r=end[x];
while(l!=r){
int md=(l+r)>>1;
if(a[md].c>=y)r=md;
else l=md+1;
}
return a[l].to;
}
// 1 if `now` is the right child of its splay parent, 0 otherwise.
int way(int now){
    int p=fa[now];
    return son[p][1]==now?1:0;
}
// Recomputes the splay-subtree sum and max of `now` from its children.
// Node 0 is the empty sentinel and is never recomputed, so its sum/mx stay 0.
void update(int now){
    if(now==0)return;
    int l=son[now][0];
    int r=son[now][1];
    sum[now]=sum[l]+sum[r]+val[now];
    mx[now]=max(max(val[now],mx[l]),mx[r]);
}
// Single splay rotation: moves `now` above its parent while keeping the
// in-order sequence of the splay tree unchanged.
void rotate(int now){
// tmp = current parent, fx = side of tmp that `now` hangs on.
int tmp=fa[now],fx=way(now);
// now's child on the opposite side moves into now's old slot under tmp
// (if that child is empty this writes fa[0], the unused sentinel node).
son[tmp][fx]=son[now][!fx];
fa[son[now][!fx]]=tmp;
// now replaces tmp under tmp's parent; when tmp is a splay root fa[tmp] is 0,
// so this write lands in the sentinel node 0.
son[fa[tmp]][way(tmp)]=now;
fa[now]=fa[tmp];
// tmp becomes now's child on the opposite side.
son[now][!fx]=tmp;
fa[tmp]=now;
// tmp is now below `now`, so refresh it first.
update(tmp);
update(now);
}
// Pushes a pending reversal down one level. A set tag means now's own children
// are already swapped; the push swaps each child's children and toggles the
// child's tag. Both children are read before any pointer is modified.
void uptag(int now){
    if(!tag[now])return;
    int kid[2]={son[now][0],son[now][1]};
    for(int k=0;k<2;k++)swap(son[kid[k]][0],son[kid[k]][1]);
    for(int k=0;k<2;k++)tag[kid[k]]^=1;
    tag[now]=0;
}
// Splays `now` upward until its parent is `rt` (rt == 0: to the root of its
// splay tree). Before each step, pending tags are pushed on the nodes taking
// part in the rotation (grandparent, parent, now). Tags above the grandparent
// are not pushed here. The final uptag leaves `now` with no pending tag.
void splay(int now,int rt){
while(fa[now]!=rt){
if(fa[fa[now]]==rt)uptag(fa[now]),uptag(now),rotate(now); // zig: parent sits right under rt
else{
uptag(fa[fa[now]]),uptag(fa[now]),uptag(now);
if(way(now)==way(fa[now]))rotate(fa[now]); // zig-zig: rotate the parent first
else rotate(now);                          // zig-zag: rotate now twice
rotate(now);
}
}
uptag(now);
}
// Returns the leftmost node of the splay subtree rooted at `now`, pushing
// pending tags on every node along the way down.
int get(int now){
    for(;;){
        uptag(now);
        int left=son[now][0];
        if(!left)return now;
        now=left;
    }
}
// Detaches the part of now's preferred path that lies to the right of `now`
// (deeper, since access() attaches deeper paths as right children). Afterwards
// `now` is its splay root with no right child. The detached part is a separate
// splay tree whose leftmost node records `now` in par[] as its path-parent.
void indep(int now){
splay(now,0);
if(!son[now][1])return;
// Leftmost node of the right subtree; splaying it directly under `now`
// makes it now's right child, so par[] is stored on that leftmost node.
int tmp=get(son[now][1]);
splay(tmp,now);
fa[son[now][1]]=0;
par[son[now][1]]=now;
son[now][1]=0;
update(now);
}
// Joins the path from the root of now's LCT tree down to `now` into one
// splay tree, with `now` as its deepest node. par[] path-parent links are read
// from, and cleared on, the leftmost node of each splay tree. On return the
// splay root is the last node joined, not necessarily `now`; callers splay
// afterwards.
void access(int now){
// Drop everything below `now` from its current path.
indep(now);
while(1){
splay(now,0);
// The leftmost node of this path carries the path-parent pointer.
int tmp=get(now),
tm2=par[tmp];
if(!tm2)break; // no path-parent: this path already contains the root
par[tmp]=0;
// Cut tm2's lower part and hang the current path as its right child.
indep(tm2);
son[tm2][1]=now;
fa[now]=tm2;
update(tm2);
now=tm2;
}
}
// Makes `now` the root of its LCT tree: expose the root-to-now path, splay
// `now` to the top, then reverse the path. The swap is applied to now's
// children immediately and the tag defers the rest to the subtrees.
void mroot(int now){
    access(now);
    splay(now,0);
    int left=son[now][0];
    son[now][0]=son[now][1];
    son[now][1]=left;
    tag[now]=1;
}
// Sets the value of LCT node `now` to `tmp`.
void oper(int now,int tmp){
    splay(now,0);      // make `now` its splay root first...
    val[now]=tmp;
    update(now);       // ...so refreshing `now` alone keeps the aggregates right
}
// Reads the tree and all operations up front. Builds one virtual tree per color
// over every (vertex, color) pair that can ever be current, and stores it as an
// LCT (each pair is one LCT node, par[] gives the initial tree edges). Then it
// replays the operations.
// op[] codes: 1 = "CC" (recolor), 2 = "CW" (reweight), 3 = "QS" (path sum),
// 4 = "QM" (path max). Queries assume x and y currently have the same color.
int main(){
    // Contest-style file I/O.
    freopen("d.in","r",stdin);
    freopen("d.out","w",stdout);
    scanf("%d%d",&n,&q);
    // tmp counts the entries collected in a[] (initial colors first).
    int tmp=0;
    fo(i,1,n)scanf("%d%d",&w[i],&c[i]),a[++tmp].p=i,a[tmp].c=c[i];
    fo(i,1,n-1){
        int u,v;
        scanf("%d%d",&u,&v);
        insert(u,v);
        insert(v,u);
    }
    scanf("\n");   // skip whitespace before the first two-letter opcode
    // Read every operation first: each "CC x y" adds the pair (x, y), so all
    // colors a vertex can take are known before the virtual trees are built.
    fo(i,1,q){
        char c1,c2;
        scanf("%c%c",&c1,&c2);
        if(c1=='C'){
            if(c2=='C')op[i]=1;
            else op[i]=2;
        }else{
            if(c2=='S')op[i]=3;
            else op[i]=4;
        }
        scanf("%d%d\n",&x[i],&y[i]);
        if(op[i]==1)a[++tmp].p=x[i],a[tmp].c=y[i];
    }
    // DFS order, depths and the binary-lifting table used by lc().
    dfs(1,0);lg2=log(n)/log(2);
    fo(j,1,lg2)fo(i,1,n)f[i][j]=f[f[i][j-1]][j-1];
    // Sort with cmp (expected: by color, then DFS order) and drop duplicate
    // pairs, which are adjacent after sorting. a[0] is a zero sentinel.
    sort(a+1,a+tmp+1,cmp);
    int tm2=0;
    fo(i,1,tmp)if((a[i].p!=a[i-1].p)||(a[i].c!=a[i-1].c)){
        a[++tm2].p=a[i].p;
        a[tm2].c=a[i].c;
    }
    tmp=tm2;
    // Build each color's virtual tree with a stack (st[0] = stack size,
    // entries are indices into a[]). Each entry gets an LCT node id in .to,
    // and tree edges are recorded as par[child node] = parent node. Missing
    // LCA vertices are appended to a[] after position tmp.
    fo(i,1,tmp){
        a[i].to=++pon;
        if(a[i].c==a[i-1].c){
            int lca=lc(a[i].p,a[i-1].p);
            // Pop and link entries that lie strictly below the LCA.
            while((st[0]>1)&&(dfn[a[st[st[0]-1]].p]>dfn[lca]))
                st[0]--,par[a[st[st[0]+1]].to]=a[st[st[0]]].to;
            if((st[0]>1)&&(dfn[a[st[st[0]-1]].p]==dfn[lca]))
                // The LCA is already on the stack: link the top to it.
                st[0]--,par[a[st[st[0]+1]].to]=a[st[st[0]]].to;
            else if(lca!=a[st[st[0]]].p){
                // Create a node for the LCA and let it replace the stack top.
                a[++tm2].p=lca;
                a[tm2].c=a[i].c;
                a[tm2].to=par[a[st[st[0]]].to]=++pon;
                st[st[0]]=tm2;
            }
        }else{
            // First entry of a new color: flush the previous color's stack.
            while(st[0]>1)st[0]--,par[a[st[st[0]+1]].to]=a[st[st[0]]].to;
            st[0]=0;
        }
        st[++st[0]]=i;
    }
    while(st[0]>1)st[0]--,par[a[st[st[0]+1]].to]=a[st[st[0]]].to;
    tmp=tm2;
    // Group entries by vertex so that a[beg[p]..end[p]] lists vertex p's
    // entries sorted by color (for find()).
    sort(a+1,a+tmp+1,cm2);
    // Bug fix: the scan is bounded by tmp. Entries past tmp can be stale
    // pairs left over from before deduplication, and an unbounded scan could
    // absorb them into the last vertex's range.
    for(int i=1,j;i<=tmp;i=j){
        for(j=i;j<=tmp&&a[j].p==a[i].p;j++);
        beg[a[i].p]=i;end[a[i].p]=j-1;
    }
    // Only the node matching a vertex's current color carries its weight;
    // every other node keeps value 0.
    fo(i,1,n){
        int tmp=find(i,c[i]);
        oper(tmp,w[i]);
    }
    fo(i,1,q){
        int tmp=find(x[i],c[x[i]]);
        if(op[i]==1){
            // Recolor: move the weight from the old color's node to the new one.
            int tm2=find(x[i],y[i]);
            oper(tmp,0);oper(tm2,w[x[i]]);
            c[x[i]]=y[i];
        }else if(op[i]==2){
            oper(tmp,y[i]);
            w[x[i]]=y[i];
        }else{
            // Path query: expose the x..y path and read the splay aggregates.
            int tm2=find(y[i],c[y[i]]);
            mroot(tmp);
            access(tm2);
            splay(tm2,0);
            if(op[i]==3)printf("%d\n",sum[tm2]);
            else printf("%d\n",mx[tm2]);
        }
    }
    return 0;
}