SDOI2018 原题识别(主席树)

题目链接

题目大意

给定 n n n个节点的树,其中包含一条非随机生成的长度为 k k k的链,剩下的节点均随机父节点连边。每个节点有一个随机的颜色,维护:
1.给定 x , y x,y x,y,求 x , y x,y x,y之间不同颜色数。
2.给定 x , y x,y x,y,对于所有满足分别在 x , y x,y x,y到根的路径上的点 a , b a,b a,b,求其询问1的答案之和。
n ≤ 1 0 5 , m ≤ 2 × 1 0 5 n\le 10^5,m\le 2\times 10^5 n105,m2×105

题解

码量比较大qwq……
我们先从链上的情况入手考虑。

链的情况

对于第一问,这是经典二维数点题。考虑 p i p_i pi表示 i i i之前第一个和它颜色相同的位置。我们以 ( i , p i ) (i,p_i) (i,pi)为坐标建点,询问不同颜色数就相当于询问 x x x坐标位于 [ l , r ] [l,r] [l,r] y y y坐标小于 l l l的点个数。直接主席树维护即可。
对于第二问,我们考虑点 i i i对答案的贡献。不妨设 x ≤ y x\le y xy,我们分三种情况讨论:
1. x < i ≤ y x<i\le y x<iy,此时贡献应该是 [ p i ≤ x ] ( x − p i ) ( y − i + 1 ) [p_i\le x](x-p_i)(y-i+1) [pix](xpi)(yi+1)
2. 1 ≤ i ≤ x 1\le i\le x 1ix a ≤ b a\le b ab,此时贡献应该是 ( i − p i ) ( y − i + 1 ) (i-p_i)(y-i+1) (ipi)(yi+1)
3. 1 ≤ i ≤ x 1\le i\le x 1ix a ≥ b a\ge b ab,此时贡献应该是 ( i − p i ) ( x − i + 1 ) (i-p_i)(x-i+1) (ipi)(xi+1)
如果直接把三种答案加起来的话会发现2,3两种情况中 a = b a=b a=b的部分算重了,减1即可。于是我们就需要维护上面的东西(2,3两个情况其实可以合起来):
第一种是
∑ i = x + 1 , p i ≤ x y ( x − p i ) ( y − i + 1 ) = ∑ i = x + 1 , p i ≤ x y x ( y + 1 ) − p i ( y + 1 ) − x i + p i i \sum_{i=x+1,p_i\le x}^y (x-p_i)(y-i+1)\\ =\sum_{i=x+1,p_i\le x}^y x(y+1)-p_i(y+1)-xi+p_ii i=x+1,pixy(xpi)(yi+1)=i=x+1,pixyx(y+1)pi(y+1)xi+pii
这个东西可以通过主席树维护四个值来计算:个数, i i i的和, p i p_i pi的和, p i i p_ii pii的和。
我们再来看第二种。
∑ i = 1 x ( i − p i ) ( x + y + 2 − 2 i ) = ∑ i = 1 x ( x + y + 2 ) ( i − p i ) − 2 i ( i − p i ) \sum_{i=1}^x(i-p_i)(x+y+2-2i)\\ =\sum_{i=1}^x(x+y+2)(i-p_i)-2i(i-p_i) i=1x(ipi)(x+y+22i)=i=1x(x+y+2)(ipi)2i(ipi)
这个东西没有了对 p i p_i pi的限制条件,因此直接前缀和维护即可。(当然如果你非要主席树的话我也不能拦着qwq)
到此为止,链的情况被我们在 O ( n l o g n ) O(nlogn) O(nlogn)的时间内做完了。

推广到树

注意到树除了那条链其它都是随机的,因此每个点到链距离的期望是 O ( l o g n ) O(logn) O(logn)的。每个颜色也是随机的,因此每个颜色出现次数的期望是 O ( 1 ) O(1) O(1)的。
也就是说,对于两个点 x , y x,y x,y的LCA,记为 l l l,必有一个点到其距离为 O ( l o g n ) O(logn) O(logn)。不妨就设这个点为 x x x,考虑第一问怎么做。
我们先计算出 [ l , y ] [l,y] [l,y]中不同的颜色数(注意下面的区间都指的是一条链),这个可以直接主席树。接下来做的事就是暴力枚举 [ x , l ) [x,l) [x,l)中的每个颜色,看看它是否在 [ l , y ] [l,y] [l,y]中出现了,直接统计。判断方法就是暴力枚举所有颜色和它相同的点即可。
因此第一问的复杂度也是 O ( n l o g n ) O(nlogn) O(nlogn)的。
考虑第二问,我们可以划分成如下三个子问题:
1. a ∈ [ 1 , l ) , b ∈ [ 1 , y ] a\in [1,l),b\in [1,y] a[1,l),b[1,y]。这实际上就是链的情况,主席树统计即可。
2. a ∈ [ l , x ] , b ∈ [ 1 , l ) a\in [l,x],b\in [1,l) a[l,x],b[1,l)。这其实也是一条链,我们可以稍微转化一下,先求出 a ∈ [ 1 , x ] , b ∈ [ 1 , l ) a\in [1,x],b\in [1,l) a[1,x],b[1,l)的答案,然后减去多算的。
多算的东西是 ∑ 2 ( i − p i ) ( l − i ) − 1 \sum 2(i-p_i)(l-i)-1 2(ipi)(li)1,直接前缀和就能维护。
3. a ∈ [ l , x ] , b ∈ [ l , y ] a\in [l,x],b\in [l,y] a[l,x],b[l,y]。这个情况很难算,我们也考虑分开计算贡献。考虑存在于 [ l , y ] [l,y] [l,y]中的点 i i i的贡献为 [ p i < l ] ( y − i + 1 ) ( x − l + 1 ) [p_i<l](y-i+1)(x-l+1) [pi<l](yi+1)(xl+1),主席树维护即可。
再考虑存在于 [ l , x ] [l,x] [l,x]中点 i i i的贡献,首先它必须是所有与它颜色相同的点中第一个在 [ l , x ] [l,x] [l,x]中出现的,它不能在 [ l , y ] [l,y] [l,y]中包含和它颜色相同的点。不妨令 j j j [ l , y ] [l,y] [l,y]中第一个和它颜色相同的点(如果不存在则为 y + 1 y+1 y+1),那么其贡献为 [ p i < l ] ( x − i + 1 ) ( j − l ) [p_i<l](x-i+1)(j-l) [pi<l](xi+1)(jl)
暴力枚举点是 O ( l o g n ) O(logn) O(logn)的,找第一次出现时 O ( 1 ) O(1) O(1)的,因此总复杂度还是 O ( n l o g n ) O(nlogn) O(nlogn)的,只是常数比较大。

#include 
namespace IOStream {
     
	const int MAXR = 1 << 23;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
     
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
     
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
     
		read(a), read(x...);
	}
	inline int reads(char *s) {
     
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() {
     
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	inline void printc(char c) {
     
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(char *s) {
     
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	template<typename T> inline void print(T x, char c = '\n') {
     
		if (x < 0) printc('-'), x = -x;
		if (x) {
     
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
     
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
#define cls(a) memset(a, 0, sizeof(a))

const int MAXN = 200005, MAXT = 2000005;
struct Edge {
      int to, next; } edge[MAXN];
int dfn[MAXN], st[20][MAXN], head[MAXN], lg[MAXN], tot, n, m, K, T;
void addedge(int u, int v) {
     
	edge[++tot] = (Edge) {
      v, head[u] };
	head[u] = tot;
}
int lst[MAXN], col[MAXN], dep[MAXN], rt[MAXN], app[MAXN], ed[MAXN];
struct Value {
     
	ll sum1, sum2, sum3, sum4;
	Value() {
      sum1 = sum2 = sum3 = sum4 = 0; }
	Value& operator+=(const Value &v) {
     
		sum1 += v.sum1, sum2 += v.sum2, sum3 += v.sum3, sum4 += v.sum4;
		return *this;
	}
	Value& operator-=(const Value &v) {
     
		sum1 -= v.sum1, sum2 -= v.sum2, sum3 -= v.sum3, sum4 -= v.sum4;
		return *this;
	}
} nd[MAXT];
//sum1=1,sum2=p[i],sum3=i,sum4=p[i]*i
ll pre2[MAXN], pre3[MAXN]; int ptot;
//pre1=1,pre2=i-p[i],pre3=i(i-p[i])
int ls[MAXT], rs[MAXT], par[MAXN], vis[MAXN];

//presistence segment tree
int update(int p, int x, int y, int l = 0, int r = n) {
     
	int q = ++ptot; nd[q] = nd[p];
	++nd[q].sum1, nd[q].sum2 += y, nd[q].sum3 += x, nd[q].sum4 += (ll)x * y;
	if (l == r) return q;
	int mid = (l + r) >> 1;
	if (y <= mid) ls[q] = update(ls[p], x, y, l, mid), rs[q] = rs[p];
	else rs[q] = update(rs[p], x, y, mid + 1, r), ls[q] = ls[p];
	return q;
}
void query(Value &v, int p, int q, int a, int b, int l = 0, int r = n) {
     //x in (p,q],y in [a,b]
	if (a > r || b < l || p == q) return;
	if (a <= l && b >= r) {
      v += nd[q], v -= nd[p]; return; }
	int mid = (l + r) >> 1;
	query(v, ls[p], ls[q], a, b, l, mid);
	query(v, rs[p], rs[q], a, b, mid + 1, r);
}

vector<int> pla[MAXN];
void dfs(int u, int fa) {
     
	st[0][dfn[u] = ++tot] = u, dep[u] = dep[fa] + 1;
	pla[col[u]].push_back(u), lst[u] = app[col[u]];
	int t = app[col[u]]; app[col[u]] = u;
	pre2[u] = pre2[fa] + dep[u] - dep[lst[u]];
	pre3[u] = pre3[fa] + (ll)(dep[u] - dep[lst[u]]) * dep[u];
	rt[u] = update(rt[fa], dep[u], dep[lst[u]]);
	for (int i = head[u]; i; i = edge[i].next) {
     
		dfs(edge[i].to, u);
		st[0][++tot] = u;
	}
	app[col[u]] = t, ed[u] = tot;
}
int get_min(int x, int y) {
      return dep[x] < dep[y] ? x : y; }
int get_lca(int x, int y) {
     
	x = dfn[x], y = dfn[y];
	if (x > y) swap(x, y);
	int l = lg[y - x + 1];
	return get_min(st[l][x], st[l][y - (1 << l) + 1]);
}
int on_link(int x, int y, int p) {
     //x is ancestor of y
	return dfn[x] <= dfn[p] && ed[x] >= dfn[p] &&
		dfn[p] <= dfn[y] && ed[p] >= dfn[y];
}

int solve1(int x, int y) {
     
	++tot;
	int la = get_lca(x, K), lb = get_lca(y, K);
	if (la > lb) swap(la, lb), swap(x, y);
	int l = get_lca(x, y);
	Value v; query(v, rt[par[l]], rt[y], 0, dep[l] - 1);
	int res = v.sum1;
	for (int i = x; i != l; i = par[i]) if (vis[col[i]] != tot) {
     
		vis[col[i]] = tot;
		int flag = 1;
		for (int j : pla[col[i]])
			if (on_link(l, y, j)) {
      flag = 0; break; }
		res += flag;
	}
	return res;
}
ll calc_link(int x, int y, const Value &v) {
     //x is ancestor of y
	int a = dep[x], b = dep[y];
	ll res = (v.sum1 * a - v.sum2) * (b + 1) - v.sum3 * a + v.sum4;
	return res + (a + b + 2) * pre2[x] - 2 * pre3[x] - a;
}
ll solve2(int x, int y) {
     
	++tot;
	int la = get_lca(x, K), lb = get_lca(y, K);
	if (la > lb) swap(la, lb), swap(x, y);
	int l = get_lca(x, y), pl = par[l], dl = dep[l];
	Value v1, v2;
	query(v1, rt[pl], rt[y], 0, dl - 1);
	query(v2, rt[pl], rt[x], 0, dl - 1);
	ll res = calc_link(pl, y, v1) + calc_link(pl, x, v2) -
		2 * (dl * pre2[pl] - pre3[pl]) + dl - 1;
	res += ((dep[y] + 1) * v1.sum1 - v1.sum3) * (dep[x] - dl + 1);
	int tp = 0;
	for (int i = x; i != l; i = par[i]) app[++tp] = i;
	app[++tp] = l;
	while (tp > 0) {
     
		int i = app[tp--];
		if (vis[col[i]] != tot) {
     
			vis[col[i]] = tot;
			int mn = dep[y] + 1;
			for (int j : pla[col[i]])
				if (on_link(l, y, j) && mn > dep[j]) mn = dep[j];
			res += (ll)(mn - dl) * (dep[x] - dep[i] + 1);
		}
	}
	return res;
}

unsigned int SA, SB, SC;
unsigned int rng61(){
     
    SA ^= SA << 16;
    SA ^= SA >> 5;
    SA ^= SA << 1;
    unsigned int t = SA;
    SA = SB;
    SB = SC;
    SC ^= t ^ SA;
    return SC;
}
void gen(){
     
    read(n, K, SA, SB, SC);
    for(int i = 2; i <= K; i++) addedge(par[i] = i - 1, i);
    for(int i = K + 1; i <= n; i++)
        addedge(par[i] = rng61() % (i - 1) + 1, i);
    for(int i = 1; i <= n; i++) col[i] = rng61() % n + 1;
}
int main() {
     
	for (read(T); T--;) {
     
		tot = 0, cls(head), cls(vis), cls(app);
		gen();
		for (int i = 1; i <= n; i++) pla[i].clear();
		dfs(1, ptot = tot = 0);
		for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
		for (int i = 1; i < 20; i++)
		for (int j = 1; j + (1 << i) - 1 <= tot; j++)
			st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
		tot = 0;
		for (read(m); m--;) {
     
			int a, b, c; read(a, b, c);
			if (a == 1) print(solve1(b, c));
			else print(solve2(b, c));
		}
	}
	ioflush();
	return 0;
}

你可能感兴趣的:(主席树)