BZOJ2243 [SDOI2011]染色

题意：给定一棵树，每个节点有一种颜色。支持两种操作：把路径 (a, b) 上的所有节点染成同一种颜色；询问路径 (a, b) 上的颜色共分成几段（相邻且颜色相同的节点算作同一段）。

分析:

树链剖分套线段树。第一次写这种题，代码比较乱，也犯了不少错误；在容易出错的地方加了注释，提醒自己以后不要再犯。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Segment-tree shorthands. They expand textually and rely on the enclosing function
// having variables named o (node index), L and R (the node's covered range); init
// additionally reads the global n.
// NOTE(review): because "m" is a macro, no other identifier in this file may be named m.
// lc/rc expand without parentheses (o<<1, o<<1|1); they are safe as array subscripts or
// whole arguments, but not inside larger arithmetic such as lc+1.
#define m ((L+R)>>1)   // midpoint of [L, R]
#define lc o<<1        // left child index
#define rc o<<1|1      // right child index
#define ls lc,L,m      // (node, L, R) argument triple for the left child
#define rs rc,m+1,R    // (node, L, R) argument triple for the right child
#define init 1,1,n     // (node, L, R) argument triple for the root, covering [1, n]
// Array capacity for nodes (presumably n <= 1e5 plus slack — confirm against the problem limits).
const int N = 100005;
// Operation letter read by main with "%s". The buffer was char op[1], which overflows:
// even a one-letter token needs 2 bytes (the letter plus the terminating '\0').
char op[8];
// n: node count, q: operation count, e: edge counter for add(),
// x/y/z: input scratch, tt: next position in the base array.
int n, q, e, x, y, z, tt;
// Segment tree over base-array positions 1..n (4N nodes):
//   c  - number of colour segments in the node's range
//   l  - colour at the left end, r - colour at the right end
//   st - pending "assign whole range" tag, -1 = none (main memsets it to -1)
int c[N*4], l[N*4], r[N*4], st[N*4];
// v: initial colour of each node; hd/nxt/to: adjacency list (2 directed edges per tree edge).
int v[N], hd[N], nxt[N*2], to[N*2];
// Heavy-light decomposition: d depth (root = 1), p node -> base position,
// tp top of the node's chain, f parent, s heavy child (0 if none),
// sz subtree size, pp base position -> node (inverse of p).
int d[N], p[N], tp[N], f[N], s[N], sz[N], pp[N];

// Append the directed edge x -> y to x's adjacency list (head insertion via hd/nxt/to).
void add(int x, int y) {
	++e;
	to[e] = y;
	nxt[e] = hd[x];
	hd[x] = e;
}
// Recompute node o's summary: c = number of colour segments, l/r = colours at the
// two ends of [L, R]. Both conditions matter: a pending tag (st[o] != -1) overrides
// whatever the children say, and a leaf (L == R) has no children to merge.
void pu(int o, int L, int R) {
	if(st[o] != -1) {
		c[o] = 1;
		l[o] = r[o] = st[o];
		return;
	}
	if(L < R) {
		l[o] = l[lc];
		r[o] = r[rc];
		// Merge the halves; if the colours meeting at the midpoint match, one segment is shared.
		c[o] = c[lc] + c[rc] - (r[lc] == l[rc]);
	}
}
// Push node o's pending assign tag (if any) to both children and clear it.
// Only the tags move here; the children's c/l/r are refreshed later by pu.
void pd(int o) {
	if(st[o] == -1) return;
	st[lc] = st[rc] = st[o];
	st[o] = -1;
}
// Build node o over base positions [L, R] from the nodes' initial colours.
void bd(int o, int L, int R) {
	if(L < R) {
		bd(ls);
		bd(rs);
		pu(o, L, R);
	} else {
		// Base position L holds tree node pp[L], so its colour is v[pp[L]]
		// (not v[L] and not v[p[L]]).
		c[o] = 1;
		l[o] = r[o] = v[pp[L]];
	}
}
// Return the colour at base position p, walking down from node o over [L, R].
// A tag met on the way down covers that whole subtree, so it is the answer;
// otherwise the colour stored in the leaf is.
int gt(int o, int L, int R, int p) {
	while(st[o] == -1 && L < R) {
		const int mid = (L + R) >> 1;
		if(p <= mid) {
			o = o << 1;
			R = mid;
		} else {
			o = o << 1 | 1;
			L = mid + 1;
		}
	}
	return st[o] != -1 ? st[o] : l[o];
}
void up(int o, int L, int R, int l, int r, int p) {
    if(l <= L && r >= R) st[o] = p;
    else {
        pd(o); if(l <= m) up(ls, l, r, p); else pu(ls);
        if(r > m) up(rs, l, r, p); else pu(rs);
    }
    pu(o, L, R);
}
// Count colour segments within base positions [ll, rr] under node o (covering [L, R]).
int qry(int o, int L, int R, int ll, int rr) {
	// A tagged node is a single colour, so any sub-range of it is one segment.
	if(st[o] != -1) return 1;
	if(ll <= L && rr >= R) return c[o];
	const int mid = (L + R) >> 1;
	if(rr <= mid) return qry(lc, L, mid, ll, rr);
	if(ll > mid) return qry(rc, mid + 1, R, ll, rr);
	int total = qry(lc, L, mid, ll, rr) + qry(rc, mid + 1, R, ll, rr);
	// The query spans the midpoint; matching colours there join two segments into one.
	if(r[lc] == l[rc]) --total;
	return total;
}

// First HLD pass from x. It fills depth d, parent f and subtree size sz for
// x's subtree, and picks the heavy child s[x]: the first child with strictly
// largest subtree. d[] doubles as the visited mark, so the caller must set the
// root's depth first.
void dfs1(int x) {
	sz[x] = 1;
	int heaviest = 0;
	for(int k = hd[x]; k; k = nxt[k]) {
		const int ch = to[k];
		if(d[ch]) continue;
		d[ch] = d[x] + 1;
		f[ch] = x;
		dfs1(ch);
		sz[x] += sz[ch];
		if(sz[ch] > heaviest) {
			heaviest = sz[ch];
			s[x] = ch;
		}
	}
}
// Second HLD pass: assign base positions (p/pp) and chain tops (tp) below x.
// The heavy child is numbered first so that each chain occupies consecutive positions.
void dfs2(int x) {
	const int heavy = s[x];
	if(heavy) {
		tp[heavy] = tp[x];
		p[heavy] = ++tt;
		pp[tt] = heavy;
		dfs2(heavy);
	}
	for(int k = hd[x]; k; k = nxt[k]) {
		const int ch = to[k];
		if(ch == f[x] || ch == heavy) continue;
		// Every light child starts a new chain.
		tp[ch] = ch;
		p[ch] = ++tt;
		pp[tt] = ch;
		dfs2(ch);
	}
}
// Number of colour segments on the tree path x..y.
int qr(int x, int y) {
	int segs = 0;
	for(; tp[x] != tp[y]; x = f[tp[x]]) {
		// Climb from whichever endpoint has the deeper chain top (compare tp depths, not x/y).
		if(d[tp[y]] > d[tp[x]]) swap(x, y);
		const int top = tp[x];
		segs += qry(init, p[top], p[x]);
		// The chain top and its parent are adjacent on the path; equal colours
		// mean the two separately counted segments are really one.
		if(gt(init, p[f[top]]) == gt(init, p[top])) --segs;
	}
	// Same chain now: count from the shallower endpoint y, not from tp[x].
	if(d[y] > d[x]) swap(x, y);
	return segs + qry(init, p[y], p[x]);
}
// Paint every node on the tree path x..y with colour z, one chain segment at a time.
void upd(int x, int y, int z) {
	for(; tp[x] != tp[y]; x = f[tp[x]]) {
		if(d[tp[y]] > d[tp[x]]) swap(x, y);
		up(init, p[tp[x]], p[x], z);
	}
	if(d[y] > d[x]) swap(x, y);
	up(init, p[y], p[x], z);
}

// Read the tree and initial colours, build the HLD + segment tree, then answer
// q operations: "Q a b" prints the segment count on path a..b, and any other
// letter (expected "C a b c") paints path a..b with colour c.
int main() {
	memset(st, -1, sizeof st);  // -1 = no pending tag on any segment-tree node
	scanf("%d%d", &n, &q);
	for(int i = 1; i <= n; i++) scanf("%d", &v[i]);
	for(int i = 1; i < n; i++) scanf("%d%d", &x, &y), add(x, y), add(y, x);
	// Root at node 1. Its depth must be non-zero before dfs1 (d[] is the visited mark).
	// The root's own chain top, base position and parent entry have to be set by hand
	// before dfs2, which only fills them in for descendants.
	d[1] = 1, dfs1(1), tp[1] = 1, p[1] = ++tt, pp[tt] = 1, f[1] = 1, dfs2(1), bd(init);
	// Local, width-limited buffer: reading "%s" into the old 1-byte op overflowed it.
	char cmd[8];
	while(q--) {
		if(scanf("%7s%d%d", cmd, &x, &y) != 3) break;  // stop on truncated/malformed input
		if(cmd[0] == 'Q') printf("%d\n", qr(x, y));
		else scanf("%d", &z), upd(x, y, z);
	}
	return 0;
}


你可能感兴趣的:(线段树,树链剖分)