这个就不多说了,树链剖分模板题。
#include "stdio.h" #include "math.h" #include "iostream" #include "string.h" #include "vector" using namespace std; const int N=30005,Lg=14,inf=(int)(2e9); int n,tmp; int fa[N][Lg+1];//fa[i][j]表示i向上2^j的祖先 int v[N];//v[i]是第i个点的权值 struct seg{ int l,r,sum,mx;}tr[N<<2]; vector <int> e[N];//边集数组 int size[N];//size[i]表示大小 bool vis[N];//是否访问 int deep[N];//深度 int belong[N];//相当于重链编号,当前重链的根节点 int pos[N];//当前点对应的线段树编号 void init(){ cin>>n; int i,a,b; for (i=1;i<n;i++){ scanf("%d%d",&a,&b); e[a].push_back(b); e[b].push_back(a); } for(i=1;i<=n;i++) scanf("%d",v+i); } void dfs1(int root){ int i; for (i=1;i<=Lg;i++){ if(deep[root]<(1<<i)) break; fa[root][i]=fa[fa[root][i-1]][i-1];//倍增 } for (i=0;i<e[root].size();i++){ if(vis[e[root][i]]==true) continue; deep[e[root][i]]=deep[root]+1; fa[e[root][i]][0]=root; vis[e[root][i]]=true; deep[e[root][i]]=deep[root]+1; dfs1(e[root][i]); vis[e[root][i]]=false; size[root]+=size[e[root][i]]; } size[root]++; } void dfs2(int root,int bel){ int i,k=-1,siz=0; belong[root]=bel; pos[root]=++tmp; for (i=0;i<e[root].size();i++){ if(size[e[root][i]]>siz&&deep[root]<deep[e[root][i]]) k=e[root][i],siz=size[e[root][i]]; } if(k==-1) return ; dfs2(k,bel); for (i=0;i<e[root].size();i++){ if(k!=e[root][i]&&deep[root]<deep[e[root][i]]) dfs2(e[root][i],e[root][i]); } } void build(int l,int r,int tmp){ tr[tmp].l=l,tr[tmp].r=r; if(l==r) return; build(l,(l+r)>>1,tmp<<1); build((l+r+2)>>1,r,tmp<<1|1); } void change(int posi,int name,int cto){ if(tr[posi].l==tr[posi].r) {tr[posi].sum=tr[posi].mx=cto; return;} int mid=(tr[posi].l+tr[posi].r)>>1; if(name<=mid) change(posi<<1,name,cto); else change(posi<<1|1,name,cto); tr[posi].mx=max(tr[posi<<1].mx,tr[posi<<1|1].mx); tr[posi].sum=tr[posi<<1].sum+tr[posi<<1|1].sum; } int querymax(int p,int l,int r){ if(tr[p].l==l&&tr[p].r==r) return tr[p].mx; int mid=(tr[p].l+tr[p].r)>>1; if(r<=mid) return querymax(p<<1,l,r); if(l>mid) return querymax(p<<1|1,l,r); return max(querymax(p<<1,l,mid),querymax(p<<1|1,mid+1,r)); } int querysum(int p,int l,int r){ if(tr[p].l==l&&tr[p].r==r) return tr[p].sum; int mid=(tr[p].l+tr[p].r)>>1; if(r<=mid) return querysum(p<<1,l,r); if(l>mid) return querysum(p<<1|1,l,r); return querysum(p<<1,l,mid)+querysum(p<<1|1,mid+1,r); } int lca(int x,int y){ if(deep[x]<deep[y]) x+=y,y=x-y,x-=y; int len=deep[x]-deep[y],i; for (i=Lg;i>=0;i--) if(len&(1<<i)) x=fa[x][i]; for (i=Lg;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; if(x==y) return x; return fa[x][0]; } int qmax(int u,int v){ int mx=-inf; while (belong[u]!=belong[v]){ mx=max(mx,querymax(1,pos[belong[u]],pos[u])); u=fa[belong[u]][0]; } mx=max(mx,querymax(1,pos[v],pos[u])); return mx; } int qsum(int u,int v){ int sum=0; while (belong[u]!=belong[v]){ sum+=querysum(1,pos[belong[u]],pos[u]); u=fa[belong[u]][0]; } sum+=querysum(1,pos[v],pos[u]); return sum; } void work(){ int i,q,a,b; char s[10]; build(1,n,1); for (i=1;i<=n;i++) change(1,pos[i],v[i]); scanf("%d",&q); for (i=1;i<=q;i++){ scanf("%s%d%d",s,&a,&b); if(s[0]=='C') change(1,pos[a],b),v[a]=b; else{ int t=lca(a,b); if(s[1]=='M') printf("%d\n",max(qmax(a,t),qmax(b,t))); else printf("%d\n",qsum(a,t)+qsum(b,t)-v[t]); } } } int main(){ init(); vis[1]=true; dfs1(1); dfs2(1,1); vis[1]=false; work(); return 0; }