题解：本题是 bzoj 2120（带修改莫队）与 bzoj 3757（树上莫队）的结合版。
注意：排序时，第一关键字为左端点所在的块，第二关键字为右端点所在的块，第三关键字为该询问的时间戳，即该询问之前已经执行过的修改次数（修改均发生在询问之前）。
// Offline Mo's algorithm on a tree with point modifications.
//
// Input (stdin): n m k; V[1..m] (value of each colour type); w[1..n]
// (weight of the j-th occurrence of a colour); n-1 undirected tree edges;
// the initial colour c[i] of every node; then k operations "op x y":
//   op == 0 : set the colour of node x to y;
//   otherwise: query the tree path x..y.
// For each query the program prints
//     sum over colours t present on the path of V[t] * (w[1] + ... + w[cnt_t]),
// where cnt_t is the number of path nodes coloured t, one line per query,
// in input order.
//
// Method: nodes are grouped into blocks of roughly block = n^(2/3)/2 nodes
// during a DFS. Queries are sorted by (block of x, block of y, number of
// preceding modifications). The current path set is moved between queries by
// toggling node membership, and modifications are applied or undone to reach
// each query's timestamp.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>

const int N = 100003;  // capacity for nodes, colour types and operations
const int LOG = 17;    // highest binary-lifting level (2^17 > N)
typedef long long LL;

int n, m, k, block, top;
// Adjacency list; every undirected edge is stored twice. Named `nxt` rather
// than `next` so it can never clash with std::next.
int nxt[N * 2], v[N * 2], mi[20];
// point: first edge of a node; deep: depth (root = 1); st: DFS stack used for
// blocking; num[t]: occurrences of colour t in the current path set;
// fa[x][i]: 2^i-th ancestor of x (0 if none).
int point[N], deep[N], st[N], num[N], fa[N][20];
// c: current colour of each node; sz: DFS counter; cnt: number of blocks;
// num1 / num2: modification / query counters.
int c[N], tot, sz, cnt, num1, num2;
// belong: block of each node; dfsn: DFS entry order; vis: node is in the
// current path set; last: latest colour of each node while reading input.
int belong[N], dfsn[N], vis[N], last[N];
// ans: per-query answers in input order; ans1: value of the current set.
LL ans[N], V[N], w[N], ans1;

// A path query. Endpoints are stored so that dfsn[x] <= dfsn[y]. `time` is
// the number of modifications read before this query.
struct Query {
    int x, y, id, time;
} q[N];

// A point modification: node `pos` receives colour `c`. `pre` is the colour it
// had before, so the modification can be undone.
struct Modify {
    int pre, c, pos;
} p[N];

// Reads one (optionally negative) decimal integer from stdin.
int read() {
    int x = 0, f = 1;
    int ch = std::getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') f = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * f;
}

// Mo ordering: block of x, then block of y, then modification timestamp.
bool cmp(const Query& a, const Query& b) {
    if (belong[a.x] != belong[b.x]) return belong[a.x] < belong[b.x];
    if (belong[a.y] != belong[b.y]) return belong[a.y] < belong[b.y];
    return a.time < b.time;
}

// Adds the undirected edge x-y to the adjacency list.
void add(int x, int y) {
    tot++; nxt[tot] = point[x]; point[x] = tot; v[tot] = y;
    tot++; nxt[tot] = point[y]; point[y] = tot; v[tot] = x;
}

// DFS from x (parent f). It fills dfsn/deep/fa and partitions the tree into
// blocks. Each node is pushed on st[] after its children. Whenever the
// not-yet-assigned nodes gathered below x reach `block`, they are popped into
// a new block. Returns the count of x's subtree nodes still unassigned,
// including x itself.
// NOTE(review): recursion depth equals tree height (up to n); deep chains may
// need a larger stack.
int dfs(int x, int f) {
    int pending = 0;
    dfsn[x] = ++sz;
    for (int i = 1; i <= LOG; i++) {
        if (deep[x] - mi[i] < 0) break;
        fa[x][i] = fa[fa[x][i - 1]][i - 1];
    }
    for (int i = point[x]; i; i = nxt[i]) {
        if (v[i] == f) continue;
        deep[v[i]] = deep[x] + 1;
        fa[v[i]][0] = x;
        pending += dfs(v[i], x);
        if (pending >= block) {
            ++cnt;
            for (int j = 1; j <= pending; j++) belong[st[top--]] = cnt;
            pending = 0;
        }
    }
    st[++top] = x;
    return pending + 1;
}

// Lowest common ancestor by binary lifting.
int lca(int x, int y) {
    if (deep[x] < deep[y]) std::swap(x, y);
    int diff = deep[x] - deep[y];
    for (int i = 0; i <= LOG; i++)
        if (diff >> i & 1) x = fa[x][i];
    if (x == y) return x;
    for (int i = LOG; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

// Adds node x to the current set if absent, removes it otherwise, and keeps
// ans1 in sync: the cnt-th occurrence of colour t contributes V[t] * w[cnt].
void toggle(int x) {
    if (!vis[x]) {
        vis[x] = 1;
        num[c[x]]++;
        ans1 += V[c[x]] * w[num[c[x]]];
    } else {
        vis[x] = 0;
        ans1 -= V[c[x]] * w[num[c[x]]];
        num[c[x]]--;
    }
}

// Toggles every node on the path x..y except their LCA.
void solve(int x, int y) {
    while (x != y) {
        if (deep[x] > deep[y]) {
            toggle(x);
            x = fa[x][0];
        } else {
            toggle(y);
            y = fa[y][0];
        }
    }
}

// Sets the colour of node x to col. If x is currently in the set, it is
// removed and re-added so that ans1 stays consistent.
void change(int x, int col) {
    if (vis[x]) {
        toggle(x);
        c[x] = col;
        toggle(x);
    } else {
        c[x] = col;
    }
}

int main() {
    std::scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= m; i++) V[i] = read();
    for (int i = 1; i <= n; i++) w[i] = read();
    for (int i = 1; i < n; i++) {
        // Two statements: argument evaluation order is unspecified.
        int x = read();
        int y = read();
        add(x, y);
    }
    for (int i = 1; i <= n; i++) c[i] = read(), last[i] = c[i];

    // Clamp to 1 so tiny trees still get a well-defined block size.
    block = std::max(1, (int)(std::pow((double)n, 2.0 / 3) * 0.5));
    mi[0] = 1;
    for (int i = 1; i <= LOG; i++) mi[i] = mi[i - 1] * 2;
    deep[1] = 1;
    dfs(1, 0);
    // Nodes left on the stack join the last block.
    while (top) belong[st[top--]] = cnt;

    for (int i = 1; i <= k; i++) {
        int op = read();
        int x = read();
        int y = read();
        if (op == 0) {
            ++num1;
            p[num1].pos = x;
            p[num1].c = y;
            p[num1].pre = last[x];
            last[x] = y;
        } else {
            ++num2;
            if (dfsn[x] > dfsn[y]) std::swap(x, y);
            q[num2].x = x;
            q[num2].y = y;
            q[num2].id = num2;
            q[num2].time = num1;
        }
    }
    std::sort(q + 1, q + num2 + 1, cmp);

    // The current set is "path curX..curY minus the LCA". Starting with
    // curX == curY it is empty. Moving to (qx, qy) toggles path(curX, qx) and
    // path(curY, qy), because these path sets compose by symmetric
    // difference. This also handles the first query, and zero queries.
    int curTime = 0, curX = 1, curY = 1;
    for (int i = 1; i <= num2; i++) {
        while (curTime < q[i].time) {
            ++curTime;
            change(p[curTime].pos, p[curTime].c);
        }
        while (curTime > q[i].time) {
            change(p[curTime].pos, p[curTime].pre);
            --curTime;
        }
        solve(curX, q[i].x);
        solve(curY, q[i].y);
        curX = q[i].x;
        curY = q[i].y;
        // The LCA is excluded from the maintained set; count it only while
        // recording this answer.
        int t = lca(curX, curY);
        toggle(t);
        ans[q[i].id] = ans1;
        toggle(t);
    }
    for (int i = 1; i <= num2; i++) std::printf("%lld\n", ans[i]);
    return 0;
}