题目链接
给定 n n n个节点的树,其中包含一条非随机生成的长度为 k k k的链,剩下的节点均随机父节点连边。每个节点有一个随机的颜色,维护:
1.给定 x , y x,y x,y,求 x , y x,y x,y之间不同颜色数。
2.给定 x , y x,y x,y,对于所有满足分别在 x , y x,y x,y到根的路径上的点 a , b a,b a,b,求其询问1的答案之和。
n ≤ 1 0 5 , m ≤ 2 × 1 0 5 n\le 10^5,m\le 2\times 10^5 n≤105,m≤2×105
码量比较大qwq……
我们先从链上的情况入手考虑。
对于第一问,这是经典二维数点题。考虑 p i p_i pi表示 i i i之前第一个和它颜色相同的位置。我们以 ( i , p i ) (i,p_i) (i,pi)为坐标建点,询问不同颜色数就相当于询问 x x x坐标位于 [ l , r ] [l,r] [l,r]且 y y y坐标小于 l l l的点个数。直接主席树维护即可。
对于第二问,我们考虑点 i i i对答案的贡献。不妨设 x ≤ y x\le y x≤y,我们分三种情况讨论:
1. x < i ≤ y x<i\le y x<i≤y,此时贡献应该是 [ p i ≤ x ] ( x − p i ) ( y − i + 1 ) [p_i\le x](x-p_i)(y-i+1) [pi≤x](x−pi)(y−i+1)。
2. 1 ≤ i ≤ x 1\le i\le x 1≤i≤x且 a ≤ b a\le b a≤b,此时贡献应该是 ( i − p i ) ( y − i + 1 ) (i-p_i)(y-i+1) (i−pi)(y−i+1)。
3. 1 ≤ i ≤ x 1\le i\le x 1≤i≤x且 a ≥ b a\ge b a≥b,此时贡献应该是 ( i − p i ) ( x − i + 1 ) (i-p_i)(x-i+1) (i−pi)(x−i+1)。
如果直接把三种答案加起来的话会发现2,3两种情况中 a = b a=b a=b的部分算重了,减1即可。于是我们就需要维护上面的东西(2,3两个情况其实可以合起来):
第一种是
∑ i = x + 1 , p i ≤ x y ( x − p i ) ( y − i + 1 ) = ∑ i = x + 1 , p i ≤ x y x ( y + 1 ) − p i ( y + 1 ) − x i + p i i \sum_{i=x+1,p_i\le x}^y (x-p_i)(y-i+1)\\ =\sum_{i=x+1,p_i\le x}^y x(y+1)-p_i(y+1)-xi+p_ii i=x+1,pi≤x∑y(x−pi)(y−i+1)=i=x+1,pi≤x∑yx(y+1)−pi(y+1)−xi+pii
这个东西可以通过主席树维护四个值来计算:个数, i i i的和, p i p_i pi的和, p i i p_ii pii的和。
我们再来看第二种。
∑ i = 1 x ( i − p i ) ( x + y + 2 − 2 i ) = ∑ i = 1 x ( x + y + 2 ) ( i − p i ) − 2 i ( i − p i ) \sum_{i=1}^x(i-p_i)(x+y+2-2i)\\ =\sum_{i=1}^x(x+y+2)(i-p_i)-2i(i-p_i) i=1∑x(i−pi)(x+y+2−2i)=i=1∑x(x+y+2)(i−pi)−2i(i−pi)
这个东西没有了对 p i p_i pi的限制条件,因此直接前缀和维护即可。(当然如果你非要主席树的话我也不能拦着qwq)
到此为止,链的情况被我们在 O ( n l o g n ) O(nlogn) O(nlogn)的时间内做完了。
注意到树除了那条链其它都是随机的,因此每个点到链距离的期望是 O ( l o g n ) O(logn) O(logn)的。每个颜色也是随机的,因此每个颜色出现次数的期望是 O ( 1 ) O(1) O(1)的。
也就是说,对于两个点 x , y x,y x,y的LCA,记为 l l l,必有一个点到其距离为 O ( l o g n ) O(logn) O(logn)。不妨就设这个点为 x x x,考虑第一问怎么做。
我们先计算出 [ l , y ] [l,y] [l,y]中不同的颜色数(注意下面的区间都指的是一条链),这个可以直接主席树。接下来做的事就是暴力枚举 [ x , l ) [x,l) [x,l)中的每个颜色,看看它是否在 [ l , y ] [l,y] [l,y]中出现了,直接统计。判断方法就是暴力枚举所有颜色和它相同的点即可。
因此第一问的复杂度也是 O ( n l o g n ) O(nlogn) O(nlogn)的。
考虑第二问,我们可以划分成如下三个子问题:
1. a ∈ [ 1 , l ) , b ∈ [ 1 , y ] a\in [1,l),b\in [1,y] a∈[1,l),b∈[1,y]。这实际上就是链的情况,主席树统计即可。
2. a ∈ [ l , x ] , b ∈ [ 1 , l ) a\in [l,x],b\in [1,l) a∈[l,x],b∈[1,l)。这其实也是一条链,我们可以稍微转化一下,先求出 a ∈ [ 1 , x ] , b ∈ [ 1 , l ) a\in [1,x],b\in [1,l) a∈[1,x],b∈[1,l)的答案,然后减去多算的。
多算的东西是 ∑ 2 ( i − p i ) ( l − i ) − 1 \sum 2(i-p_i)(l-i)-1 ∑2(i−pi)(l−i)−1,直接前缀和就能维护。
3. a ∈ [ l , x ] , b ∈ [ l , y ] a\in [l,x],b\in [l,y] a∈[l,x],b∈[l,y]。这个情况很难算,我们也考虑分开计算贡献。考虑存在于 [ l , y ] [l,y] [l,y]中的点 i i i的贡献为 [ p i < l ] ( y − i + 1 ) ( x − l + 1 ) [p_i<l](y-i+1)(x-l+1) [pi<l](y−i+1)(x−l+1),主席树维护即可。
再考虑存在于 [ l , x ] [l,x] [l,x]中点 i i i的贡献,首先它必须是所有与它颜色相同的点中第一个在 [ l , x ] [l,x] [l,x]中出现的,它不能在 [ l , y ] [l,y] [l,y]中包含和它颜色相同的点。不妨令 j j j为 [ l , y ] [l,y] [l,y]中第一个和它颜色相同的点(如果不存在则为 y + 1 y+1 y+1),那么其贡献为 [ p i < l ] ( x − i + 1 ) ( j − l ) [p_i<l](x-i+1)(j-l) [pi<l](x−i+1)(j−l)。
暴力枚举点是 O ( l o g n ) O(logn) O(logn)的,找第一次出现时 O ( 1 ) O(1) O(1)的,因此总复杂度还是 O ( n l o g n ) O(nlogn) O(nlogn)的,只是常数比较大。
#include
namespace IOStream {
const int MAXR = 1 << 23;
char _READ_[MAXR], _PRINT_[MAXR];
int _READ_POS_, _PRINT_POS_, _READ_LEN_;
inline char readc() {
#ifndef ONLINE_JUDGE
return getchar();
#endif
if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
char c = _READ_[_READ_POS_++];
if (_READ_POS_ == MAXR) _READ_POS_ = 0;
if (_READ_POS_ > _READ_LEN_) return 0;
return c;
}
template<typename T> inline void read(T &x) {
x = 0; register int flag = 1, c;
while (((c = readc()) < '0' || c > '9') && c != '-');
if (c == '-') flag = -1; else x = c - '0';
while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
x *= flag;
}
template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
read(a), read(x...);
}
inline int reads(char *s) {
register int len = 0, c;
while (isspace(c = readc()) || !c);
s[len++] = c;
while (!isspace(c = readc()) && c) s[len++] = c;
s[len] = 0;
return len;
}
inline void ioflush() {
fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
fflush(stdout);
}
inline void printc(char c) {
_PRINT_[_PRINT_POS_++] = c;
if (_PRINT_POS_ == MAXR) ioflush();
}
inline void prints(char *s) {
for (int i = 0; s[i]; i++) printc(s[i]);
}
template<typename T> inline void print(T x, char c = '\n') {
if (x < 0) printc('-'), x = -x;
if (x) {
static char sta[20];
register int tp = 0;
for (; x; x /= 10) sta[tp++] = x % 10 + '0';
while (tp > 0) printc(sta[--tp]);
} else printc('0');
printc(c);
}
template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
print(x, ' '), print(y...);
}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
#define cls(a) memset(a, 0, sizeof(a))
const int MAXN = 200005, MAXT = 2000005;
struct Edge {
int to, next; } edge[MAXN];
int dfn[MAXN], st[20][MAXN], head[MAXN], lg[MAXN], tot, n, m, K, T;
void addedge(int u, int v) {
edge[++tot] = (Edge) {
v, head[u] };
head[u] = tot;
}
int lst[MAXN], col[MAXN], dep[MAXN], rt[MAXN], app[MAXN], ed[MAXN];
struct Value {
ll sum1, sum2, sum3, sum4;
Value() {
sum1 = sum2 = sum3 = sum4 = 0; }
Value& operator+=(const Value &v) {
sum1 += v.sum1, sum2 += v.sum2, sum3 += v.sum3, sum4 += v.sum4;
return *this;
}
Value& operator-=(const Value &v) {
sum1 -= v.sum1, sum2 -= v.sum2, sum3 -= v.sum3, sum4 -= v.sum4;
return *this;
}
} nd[MAXT];
//sum1=1,sum2=p[i],sum3=i,sum4=p[i]*i
ll pre2[MAXN], pre3[MAXN]; int ptot;
//pre1=1,pre2=i-p[i],pre3=i(i-p[i])
int ls[MAXT], rs[MAXT], par[MAXN], vis[MAXN];
//presistence segment tree
int update(int p, int x, int y, int l = 0, int r = n) {
int q = ++ptot; nd[q] = nd[p];
++nd[q].sum1, nd[q].sum2 += y, nd[q].sum3 += x, nd[q].sum4 += (ll)x * y;
if (l == r) return q;
int mid = (l + r) >> 1;
if (y <= mid) ls[q] = update(ls[p], x, y, l, mid), rs[q] = rs[p];
else rs[q] = update(rs[p], x, y, mid + 1, r), ls[q] = ls[p];
return q;
}
void query(Value &v, int p, int q, int a, int b, int l = 0, int r = n) {
//x in (p,q],y in [a,b]
if (a > r || b < l || p == q) return;
if (a <= l && b >= r) {
v += nd[q], v -= nd[p]; return; }
int mid = (l + r) >> 1;
query(v, ls[p], ls[q], a, b, l, mid);
query(v, rs[p], rs[q], a, b, mid + 1, r);
}
vector<int> pla[MAXN];
void dfs(int u, int fa) {
st[0][dfn[u] = ++tot] = u, dep[u] = dep[fa] + 1;
pla[col[u]].push_back(u), lst[u] = app[col[u]];
int t = app[col[u]]; app[col[u]] = u;
pre2[u] = pre2[fa] + dep[u] - dep[lst[u]];
pre3[u] = pre3[fa] + (ll)(dep[u] - dep[lst[u]]) * dep[u];
rt[u] = update(rt[fa], dep[u], dep[lst[u]]);
for (int i = head[u]; i; i = edge[i].next) {
dfs(edge[i].to, u);
st[0][++tot] = u;
}
app[col[u]] = t, ed[u] = tot;
}
int get_min(int x, int y) {
return dep[x] < dep[y] ? x : y; }
int get_lca(int x, int y) {
x = dfn[x], y = dfn[y];
if (x > y) swap(x, y);
int l = lg[y - x + 1];
return get_min(st[l][x], st[l][y - (1 << l) + 1]);
}
int on_link(int x, int y, int p) {
//x is ancestor of y
return dfn[x] <= dfn[p] && ed[x] >= dfn[p] &&
dfn[p] <= dfn[y] && ed[p] >= dfn[y];
}
int solve1(int x, int y) {
++tot;
int la = get_lca(x, K), lb = get_lca(y, K);
if (la > lb) swap(la, lb), swap(x, y);
int l = get_lca(x, y);
Value v; query(v, rt[par[l]], rt[y], 0, dep[l] - 1);
int res = v.sum1;
for (int i = x; i != l; i = par[i]) if (vis[col[i]] != tot) {
vis[col[i]] = tot;
int flag = 1;
for (int j : pla[col[i]])
if (on_link(l, y, j)) {
flag = 0; break; }
res += flag;
}
return res;
}
ll calc_link(int x, int y, const Value &v) {
//x is ancestor of y
int a = dep[x], b = dep[y];
ll res = (v.sum1 * a - v.sum2) * (b + 1) - v.sum3 * a + v.sum4;
return res + (a + b + 2) * pre2[x] - 2 * pre3[x] - a;
}
ll solve2(int x, int y) {
++tot;
int la = get_lca(x, K), lb = get_lca(y, K);
if (la > lb) swap(la, lb), swap(x, y);
int l = get_lca(x, y), pl = par[l], dl = dep[l];
Value v1, v2;
query(v1, rt[pl], rt[y], 0, dl - 1);
query(v2, rt[pl], rt[x], 0, dl - 1);
ll res = calc_link(pl, y, v1) + calc_link(pl, x, v2) -
2 * (dl * pre2[pl] - pre3[pl]) + dl - 1;
res += ((dep[y] + 1) * v1.sum1 - v1.sum3) * (dep[x] - dl + 1);
int tp = 0;
for (int i = x; i != l; i = par[i]) app[++tp] = i;
app[++tp] = l;
while (tp > 0) {
int i = app[tp--];
if (vis[col[i]] != tot) {
vis[col[i]] = tot;
int mn = dep[y] + 1;
for (int j : pla[col[i]])
if (on_link(l, y, j) && mn > dep[j]) mn = dep[j];
res += (ll)(mn - dl) * (dep[x] - dep[i] + 1);
}
}
return res;
}
unsigned int SA, SB, SC;
unsigned int rng61(){
SA ^= SA << 16;
SA ^= SA >> 5;
SA ^= SA << 1;
unsigned int t = SA;
SA = SB;
SB = SC;
SC ^= t ^ SA;
return SC;
}
void gen(){
read(n, K, SA, SB, SC);
for(int i = 2; i <= K; i++) addedge(par[i] = i - 1, i);
for(int i = K + 1; i <= n; i++)
addedge(par[i] = rng61() % (i - 1) + 1, i);
for(int i = 1; i <= n; i++) col[i] = rng61() % n + 1;
}
int main() {
for (read(T); T--;) {
tot = 0, cls(head), cls(vis), cls(app);
gen();
for (int i = 1; i <= n; i++) pla[i].clear();
dfs(1, ptot = tot = 0);
for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
for (int i = 1; i < 20; i++)
for (int j = 1; j + (1 << i) - 1 <= tot; j++)
st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
tot = 0;
for (read(m); m--;) {
int a, b, c; read(a, b, c);
if (a == 1) print(solve1(b, c));
else print(solve2(b, c));
}
}
ioflush();
return 0;
}