为了练手速我花了半个小时打完了这道题。。然后debug的时候就。。23333
首先如果是一个序列显然可以用线段树区间修改,维护段中的颜色数量,左右端点的颜色来做吧。
树上也一样,我们可以把树上的区间转化为dfs序列中若干个连续区间,然后用树链剖分使区间的个数<logN,注意一下端点的问题就好了(说白了就是一道树链剖分裸题)。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #define N 100005 using namespace std; int n,m,dfsclk,bin[25],a[N],b[N],d[N],fa[N][17],sz[N],son[N],anc[N],pos[N]; int tot,fst[N],pnt[N<<1],nxt[N<<1],cvr[N<<2]; struct node{ int l,r,sum; }val[N<<2]; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void add(int x,int y){ pnt[++tot]=y; nxt[tot]=fst[x]; fst[x]=tot; } void dfs(int x){ sz[x]=1; int p,i; for (i=1; bin[i]<=d[x]; i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; if (y!=fa[x][0]){ fa[y][0]=x; d[y]=d[x]+1; dfs(y); sz[x]+=sz[y]; if (sz[y]>sz[son[x]]) son[x]=y; } } } void divide(int x,int tp){ pos[x]=++dfsclk; anc[x]=tp; int p; if (son[x]) divide(son[x],tp); for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; if (y!=fa[x][0] && y!=son[x]) divide(y,y); } } int lca(int x,int y){ if (d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y],i; for (i=0; bin[i]<=tmp; i++) if (tmp&bin[i]) x=fa[x][i]; for (i=16; i>=0; i--) if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; } return (x==y)?x:fa[x][0]; } void maintain(int k){ int l=k<<1,r=l|1; val[k].sum=val[l].sum+val[r].sum; if (val[l].r==val[r].l) val[k].sum--; val[k].l=val[l].l; val[k].r=val[r].r; } void chg(int k,int v){ val[k].sum=1; val[k].l=val[k].r=cvr[k]=v; } void pushdown(int k){ if (cvr[k]!=-1){ chg(k<<1,cvr[k]); chg(k<<1|1,cvr[k]); cvr[k]=-1; } } void build(int k,int l,int r){ cvr[k]=-1; if (l==r){ val[k].l=val[k].r=a[l]; val[k].sum=1; return; } int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); maintain(k); } void ins(int k,int l,int r,int x,int y,int z){ if (l==x && r==y){ chg(k,z); return; } int mid=(l+r)>>1; pushdown(k); if (y<=mid) ins(k<<1,l,mid,x,y,z); else if (x>mid) ins(k<<1|1,mid+1,r,x,y,z); else{ ins(k<<1,l,mid,x,mid,z); ins(k<<1|1,mid+1,r,mid+1,y,z); } maintain(k); } node qry(int k,int l,int r,int x,int y){ if (cvr[k]!=-1){ node t1; t1.sum=1; t1.l=t1.r=cvr[k]; return t1; } if (l==x && r==y) return val[k]; int mid=(l+r)>>1; if (y<=mid) return qry(k<<1,l,mid,x,y); else if (x>mid) return qry(k<<1|1,mid+1,r,x,y); else{ node t1=qry(k<<1,l,mid,x,mid),t2=qry(k<<1|1,mid+1,r,mid+1,y); t1.sum+=t2.sum; if (t1.r==t2.l) t1.sum--; t1.r=t2.r; return t1; } } void mdy(int x,int y,int z){ for (; anc[x]!=anc[y]; x=fa[anc[x]][0]) ins(1,1,n,pos[anc[x]],pos[x],z); ins(1,1,n,pos[y],pos[x],z); } node solve(int x,int y){ if (anc[x]==anc[y]) return qry(1,1,n,pos[y],pos[x]); node t1=qry(1,1,n,pos[anc[x]],pos[x]),t2; x=fa[anc[x]][0]; for (; anc[x]!=anc[y]; x=fa[anc[x]][0]){ t2=qry(1,1,n,pos[anc[x]],pos[x]); t1.sum+=t2.sum; if (t1.l==t2.r) t1.sum--; t1.l=t2.l; } t2=qry(1,1,n,pos[y],pos[x]); t1.sum+=t2.sum; if (t1.l==t2.r) t1.sum--; t1.l=t2.l; return t1; } int main(){ n=read(); m=read(); int i; bin[0]=1; for (i=1; i<=17; i++) bin[i]=bin[i-1]<<1; for (i=1; i<=n; i++) b[i]=read(); for (i=1; i<n; i++){ int x=read(),y=read(); add(x,y); add(y,x); } dfs(1); divide(1,1); for (i=1; i<=n; i++) a[pos[i]]=b[i]; build(1,1,n); char ch; while (m--){ ch=getchar(); while (ch<'A' || ch>'Z') ch=getchar(); if (ch=='C'){ int x=read(),y=read(),z=read(),tmp=lca(x,y); mdy(x,tmp,z); mdy(y,tmp,z); } else{ int x=read(),y=read(),tmp=lca(x,y); node t1=solve(x,tmp),t2=solve(y,tmp); printf("%d\n",t1.sum+t2.sum-1); } } return 0; }
by lych
2016.3.8