题意:给定一棵树,每个节点有一种颜色,支持两种操作:把一条路径上的所有节点染成同一种颜色;查询一条路径上的颜色段数(路径上相邻且颜色相同的节点算作同一段)。
分析:
树链剖分套线段树,没写过,代码写得很乱,还犯了不少错,加了点注释,以后不能犯这种错了。
#include <cstdio>
#include <cstring>
#include <algorithm>

// Path recolouring and colour-segment counting on a tree.
// The tree is cut into heavy chains (dfs1/dfs2) so that every chain occupies a
// contiguous range of positions; a segment tree with lazy "assign" tags over
// those positions stores, per range, the number of colour segments and the
// colours at both ends.

const int N = 100005;

// Operation token. scanf's %s stores the characters read plus a terminating
// NUL, so the buffer must be larger than one byte; the read is bounded with %7s.
char op[8];

int n, q;  // vertex count and operation count, read from the input
int e;     // number of adjacency-list entries used so far
int tt;    // last position handed out while laying out the chains

// Segment tree over positions 1..n; node o has children o<<1 and o<<1|1.
int c[N * 4];   // number of colour segments inside the node's range
int l[N * 4];   // colour at the left end of the node's range
int r[N * 4];   // colour at the right end of the node's range
int st[N * 4];  // pending "whole range has this colour" tag, -1 when absent

int v[N];                          // initial colour of each vertex
int hd[N], nxt[N * 2], to[N * 2];  // adjacency lists (head / next / target)
int d[N];                          // depth, root is 1; 0 means "not visited yet"
int p[N];                          // vertex -> position in the segment tree
int pp[N];                         // position -> vertex (inverse of p)
int tp[N];                         // top vertex of the chain containing the vertex
int f[N];                          // parent (the root is its own parent)
int s[N];                          // heavy child, 0 if the vertex has no child
int sz[N];                         // subtree size

// Appends the directed edge x -> y to x's adjacency list.
void add(int x, int y) {
    to[++e] = y;
    nxt[e] = hd[x];
    hd[x] = e;
}

// Recomputes c/l/r of node o (range [L, R]).
// Both checks matter: a leaf (R == L) has no children to combine, and a
// pending tag overrides whatever the children say.
void pu(int o, int L, int R) {
    if (R > L) {
        const int left = o << 1;
        const int right = o << 1 | 1;
        // Two halves share a segment when the colours at the seam are equal.
        c[o] = c[left] + c[right] - (r[left] == l[right] ? 1 : 0);
        l[o] = l[left];
        r[o] = r[right];
    }
    if (st[o] != -1) {
        c[o] = 1;
        l[o] = r[o] = st[o];
    }
}

// Moves o's pending tag onto both children and clears it on o.
// The children's c/l/r are NOT refreshed here; the caller runs pu() on them.
void pd(int o) {
    if (st[o] != -1) {
        st[o << 1] = st[o << 1 | 1] = st[o];
        st[o] = -1;
    }
}

// Builds the tree for range [L, R] from the initial vertex colours.
void bd(int o, int L, int R) {
    if (L == R) {
        // L is a position, so map it back to a vertex first:
        // v[pp[L]], not v[L] or v[p[L]].
        l[o] = r[o] = v[pp[L]];
        c[o] = 1;
        return;
    }
    const int mid = (L + R) >> 1;
    bd(o << 1, L, mid);
    bd(o << 1 | 1, mid + 1, R);
    pu(o, L, R);
}

// Returns the colour currently stored at position pos.
int gt(int o, int L, int R, int pos) {
    if (st[o] != -1) return st[o];  // the whole range has this colour
    if (L == R) return l[o];
    const int mid = (L + R) >> 1;
    return pos <= mid ? gt(o << 1, L, mid, pos)
                      : gt(o << 1 | 1, mid + 1, R, pos);
}

// Assigns colour col to every position in [lo, hi] (must intersect [L, R]).
void up(int o, int L, int R, int lo, int hi, int col) {
    if (lo <= L && hi >= R) {
        st[o] = col;
    } else {
        const int mid = (L + R) >> 1;
        pd(o);
        // A child that is not recursed into still needs pu() so that a tag
        // just pushed onto it is reflected in its c/l/r.
        if (lo <= mid) up(o << 1, L, mid, lo, hi, col);
        else pu(o << 1, L, mid);
        if (hi > mid) up(o << 1 | 1, mid + 1, R, lo, hi, col);
        else pu(o << 1 | 1, mid + 1, R);
    }
    pu(o, L, R);
}

// Returns the number of colour segments inside positions [ll, rr].
int qry(int o, int L, int R, int ll, int rr) {
    if (st[o] != -1) return 1;  // uniformly coloured range
    if (ll <= L && rr >= R) return c[o];
    const int mid = (L + R) >> 1;
    if (rr <= mid) return qry(o << 1, L, mid, ll, rr);
    if (ll > mid) return qry(o << 1 | 1, mid + 1, R, ll, rr);
    // The query crosses the seam: merge the two counts if the colours on
    // either side of the seam are equal.
    const int merged = (r[o << 1] == l[o << 1 | 1]) ? 1 : 0;
    return qry(o << 1, L, mid, ll, rr) + qry(o << 1 | 1, mid + 1, R, ll, rr) - merged;
}

// First pass: fills depth, parent, subtree size and heavy child below x.
// d[x] must already be set by the caller.
void dfs1(int x) {
    int best = 0;
    sz[x] = 1;
    for (int i = hd[x]; i; i = nxt[i]) {
        const int y = to[i];
        if (d[y]) continue;  // already visited (this is the parent)
        d[y] = d[x] + 1;
        f[y] = x;
        dfs1(y);
        sz[x] += sz[y];
        if (sz[y] > best) {
            best = sz[y];
            s[x] = y;
        }
    }
}

// Second pass: assigns positions and chain tops below x.
// The heavy child is visited first so each chain gets consecutive positions.
// tp[x], p[x] and f[x] must already be set by the caller.
void dfs2(int x) {
    if (s[x]) {
        const int h = s[x];
        tp[h] = tp[x];
        p[h] = ++tt;
        pp[tt] = h;
        dfs2(h);
    }
    for (int i = hd[x]; i; i = nxt[i]) {
        const int y = to[i];
        if (y == f[x] || y == s[x]) continue;
        tp[y] = y;  // a light child starts its own chain
        p[y] = ++tt;
        pp[tt] = y;
        dfs2(y);
    }
}

// Returns the number of colour segments on the tree path between x and y.
int qr(int x, int y) {
    int ans = 0;
    while (tp[x] != tp[y]) {
        // Compare the depths of the chain tops, not of x and y themselves.
        if (d[tp[x]] < d[tp[y]]) std::swap(x, y);
        ans += qry(1, 1, n, p[tp[x]], p[x]);
        // The chain top and its parent are adjacent on the path; if they have
        // the same colour, the two pieces form one segment.
        if (gt(1, 1, n, p[tp[x]]) == gt(1, 1, n, p[f[tp[x]]])) ans--;
        x = f[tp[x]];
    }
    if (d[x] < d[y]) std::swap(x, y);
    // Last piece runs from y to x on the shared chain: start at p[y],
    // not at p[tp[x]].
    ans += qry(1, 1, n, p[y], p[x]);
    return ans;
}

// Paints every vertex on the tree path between x and y with colour z.
void upd(int x, int y, int z) {
    while (tp[x] != tp[y]) {
        if (d[tp[x]] < d[tp[y]]) std::swap(x, y);
        up(1, 1, n, p[tp[x]], p[x], z);
        x = f[tp[x]];
    }
    if (d[x] < d[y]) std::swap(x, y);
    up(1, 1, n, p[y], p[x], z);
}

// Reads the tree and the operations from stdin and prints one line per query.
// Returns 1 if the header, colours or edges are malformed.
int main() {
    std::memset(st, -1, sizeof st);

    if (std::scanf("%d%d", &n, &q) != 2) return 1;
    // n < 1 would make bd() recurse forever; n >= N would overrun the arrays.
    if (n < 1 || n >= N) return 1;

    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", &v[i]) != 1) return 1;
    }
    for (int i = 1; i < n; i++) {
        int a, b;
        if (std::scanf("%d%d", &a, &b) != 2) return 1;
        add(a, b);
        add(b, a);
    }

    // The root's entries must be initialised by hand before each pass.
    d[1] = 1;
    dfs1(1);
    tp[1] = 1;
    p[1] = ++tt;
    pp[tt] = 1;
    f[1] = 1;
    dfs2(1);
    bd(1, 1, n);

    while (q--) {
        int x, y;
        if (std::scanf("%7s%d%d", op, &x, &y) != 3) break;
        if (op[0] == 'Q') {
            std::printf("%d\n", qr(x, y));
        } else {
            int z;
            if (std::scanf("%d", &z) != 1) break;
            upd(x, y, z);
        }
    }
    return 0;
}