给定一棵树,每个节点有颜色和权值,你需要支持四个操作:
1、改变一个点的颜色
2、改变一个点的权值
3、询问一条路径上和起点同颜色的点的和
4、询问一条路径上和起点同颜色的点的最大值
思路与“数树数”相同:
离线后按颜色分组处理所有操作,用树链剖分+线段树维护即可
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int maxn=100000+10;
// Segment tree over DFS order: sum[] holds range sums, num[] holds range maxima.
// dfn[v]: position of v in the heavy-child-first DFS order; size[v]: subtree size of v.
// h/la/go/next: singly linked lists built by add(). h[x]/la[x] are the first/last
// entry of list x, go[e] is the payload of entry e, next[e] links to the following
// entry. They first store the tree adjacency and are later reused for per-colour event lists.
int sum[maxn*4],num[maxn*4],dfn[maxn],size[maxn],h[maxn*3],la[maxn*3],go[maxn*6],next[maxn*6];
// bz[i]: true when operation i is a query whose answer is printed at the end.
bool bz[maxn];
// jump[v]: top of v's heavy chain (0 for the root's chain, because jump[1] is never set).
// a[v]: current weight, c[v]: current colour, d[v]: depth (root = 1, d[0] = 0),
// ans[i]: answer of operation i.
int jump[maxn],a[maxn],c[maxn],d[maxn],ans[maxn];
// f[v][j]: 2^j-th ancestor of v (0 once past the root).
int f[maxn][25];
// One offline event, processed when its colour's list is solved:
//   type 1 / 3 : store weight y at node x
//   type 2     : clear node x (store 0)
//   type 4     : path-sum query between x and y, answer written to ans[id]
//   type 5     : path-max query between x and y, answer written to ans[id]
struct dong{
    int type;  // event kind (1..5)
    int x;     // node, or first path endpoint for queries
    int y;     // weight, or second path endpoint for queries
    int id;    // operation index for queries (answer slot)
};
// ask[k]: the k-th offline event; events are referenced from the colour lists by index.
dong ask[maxn*7];
// n: node count, m: operation count, tot: entry counter for add(),
// top: DFS timestamp counter, cnt: event counter; the rest are loop/input scratch.
int i,j,k,l,t,n,m,tot,top,cnt;
// The two letters of an operation keyword (e.g. "CC", "CW", "QS", "QM").
char ch,hc;
void add(int x,int y){
++tot;
if (!h[x]) h[x]=tot;
go[tot]=y;
if (la[x]) next[la[x]]=tot;
la[x]=tot;
}
void dfs(int x,int y){
int t=h[x];
size[x]=1;
while (t){
if (go[t]!=y){
dfs(go[t],x);
size[x]+=size[go[t]];
}
t=next[t];
}
}
void dg(int x,int y){
dfn[x]=++top;
d[x]=d[y]+1;
f[x][0]=y;
int t=h[x],j=0;
while (t){
if (go[t]!=y) {
if (!j||size[go[t]]>size[j]) j=go[t];
}
t=next[t];
}
if (!j) return;
jump[j]=jump[x];
dg(j,x);
t=h[x];
while (t){
if (go[t]!=y&&go[t]!=j){
jump[go[t]]=go[t];
dg(go[t],x);
}
t=next[t];
}
}
// Lowest common ancestor of x and y via binary lifting on f[][] and depths d[].
int lca(int x,int y){
    // Let x be the deeper endpoint.
    if (d[x]<d[y]) swap(x,y);

    if (d[x]!=d[y]){
        // Raise x to depth d[y]+1, then one more parent step reaches d[y].
        for (int k=floor(log(d[x]-d[y])/log(2));k>=0;--k)
            if (d[f[x][k]]>d[y]) x=f[x][k];
        x=f[x][0];
    }
    if (x==y) return x;

    // Lift both while their ancestors still differ; the answer is then the parent.
    for (int k=floor(log(d[x])/log(2));k>=0;--k){
        if (f[x][k]==f[y][k]) continue;
        x=f[x][k];
        y=f[y][k];
    }
    return f[x][0];
}
// Reads characters from stdin until an uppercase letter 'A'..'Z' appears and returns it.
char get(){
    char letter;
    do {
        letter=getchar();
    } while (!(letter>='A'&&letter<='Z'));
    return letter;
}
// Point assignment: sets position a to value b in the segment tree node p
// covering [l, r], then refreshes the sum[] and num[] (max) aggregates on the way up.
void change(int p,int l,int r,int a,int b){
    if (l!=r){
        const int mid=(l+r)/2;
        const int lc=p*2, rc=p*2+1;
        if (a>mid) change(rc,mid+1,r,a,b);
        else change(lc,l,mid,a,b);
        sum[p]=sum[lc]+sum[rc];
        num[p]=max(num[lc],num[rc]);
        return;
    }
    sum[p]=b;
    num[p]=b;
}
// Range sum over [a, b] inside node p covering [l, r] (expects l <= a <= b <= r).
int query1(int p,int l,int r,int a,int b){
    if (l==a&&r==b) return sum[p];
    const int mid=(l+r)/2;
    if (b<=mid) return query1(p*2,l,mid,a,b);        // entirely in the left half
    if (a>mid) return query1(p*2+1,mid+1,r,a,b);     // entirely in the right half
    const int left=query1(p*2,l,mid,a,mid);
    const int right=query1(p*2+1,mid+1,r,mid+1,b);
    return left+right;
}
// Range maximum over [a, b] inside node p covering [l, r] (expects l <= a <= b <= r).
int query2(int p,int l,int r,int a,int b){
    if (l==a&&r==b) return num[p];
    const int mid=(l+r)/2;
    if (b<=mid) return query2(p*2,l,mid,a,b);        // entirely in the left half
    if (a>mid) return query2(p*2+1,mid+1,r,a,b);     // entirely in the right half
    const int left=query2(p*2,l,mid,a,mid);
    const int right=query2(p*2+1,mid+1,r,mid+1,b);
    return max(left,right);
}
// Sum of the values currently stored in the segment tree along the path u..v.
// Each endpoint climbs heavy chains up to w = lca(u, v). A chain whose top lies
// above w is clipped at w. w is counted from both sides, so it is subtracted once.
int getsum(int u,int v){
    const int w=lca(u,v);
    const int ends[2]={u,v};
    int total=0;
    for (int s=0;s<2;++s){
        for (int x=ends[s];d[x]>=d[w];x=f[jump[x]][0]){
            const int head=jump[x];
            const int from=(d[head]<d[w])?w:head;
            total+=query1(1,1,n,dfn[from],dfn[x]);
        }
    }
    total-=query1(1,1,n,dfn[w],dfn[w]);
    return total;
}
// Maximum of the values currently stored in the segment tree along the path u..v
// (starting from 0). Same chain-climbing scheme as the path sum; counting w
// twice is harmless for a maximum.
int getnum(int u,int v){
    const int w=lca(u,v);
    const int ends[2]={u,v};
    int best=0;
    for (int s=0;s<2;++s){
        for (int x=ends[s];d[x]>=d[w];x=f[jump[x]][0]){
            const int head=jump[x];
            const int from=(d[head]<d[w])?w:head;
            best=max(best,query2(1,1,n,dfn[from],dfn[x]));
        }
    }
    return best;
}
void solve(int x){
int t=h[x],now;
while (t){
now=go[t];
if (ask[now].type==1) change(1,1,n,dfn[ask[now].x],ask[now].y);
else if (ask[now].type==2) change(1,1,n,dfn[ask[now].x],0);
else if (ask[now].type==3) change(1,1,n,dfn[ask[now].x],ask[now].y);
else if (ask[now].type==4) ans[ask[now].id]=getsum(ask[now].x,ask[now].y);
else ans[ask[now].id]=getnum(ask[now].x,ask[now].y);
t=next[t];
}
}
int main(){
freopen("travel.in","r",stdin);freopen("travel.out","w",stdout);
scanf("%d%d",&n,&m);
fo(i,1,n) scanf("%d%d",&a[i],&c[i]);
fo(i,1,n-1){
scanf("%d%d",&j,&k);
add(j,k);
add(k,j);
}
dfs(1,0);
d[1]=1;
dg(1,0);
fo(j,1,floor(log(n)/log(2)))
fo(i,1,n)
f[i][j]=f[f[i][j-1]][j-1];
tot=0;
fill(h+1,h+100002,0);
fill(la+1,la+100002,0);
fo(i,1,n){
ask[++cnt].type=1;
ask[cnt].x=i;
ask[cnt].y=a[i];
add(c[i],cnt);
}
fo(i,1,m){
ch=get();hc=get();
scanf("%d%d",&j,&k);
if (ch=='C'&&hc=='C'){
if (c[j]!=k){
ask[++cnt].type=1;
ask[cnt].x=j;
ask[cnt].y=a[j];
add(k,cnt);
ask[++cnt].type=2;
ask[cnt].x=j;
add(c[j],cnt);
c[j]=k;
}
bz[i]=0;
}
else if (ch=='C'&&hc=='W'){
ask[++cnt].type=3;
ask[cnt].x=j;
ask[cnt].y=k;
a[j]=k;
add(c[j],cnt);
bz[i]=0;
}
else if (hc=='S'){
ask[++cnt].type=4;
ask[cnt].x=j;
ask[cnt].y=k;
ask[cnt].id=i;
add(c[j],cnt);
bz[i]=1;
}
else{
ask[++cnt].type=5;
ask[cnt].x=j;
ask[cnt].y=k;
ask[cnt].id=i;
add(c[j],cnt);
bz[i]=1;
}
}
fo(i,1,n){
ask[++cnt].type=1;
ask[cnt].x=i;
ask[cnt].y=a[i];
add(100001,cnt);
ask[++cnt].type=2;
ask[cnt].x=i;
add(c[i],cnt);
c[i]=100001;
}
fo(i,0,100000){
t=t;
solve(i);
}
fo(i,1,m)
if (bz[i]) printf("%d\n",ans[i]);
}