题目大意
构造一棵\([1,n]\)的线段树,有\(q\)个询问\([x,y]\),每次查询\([x,y]\)的所有子区间在线段树上经过的点数之和。
\(n,q \leq 500000\)
Solution
一开始方向错了。。。。
显然线段树上只有和\([x,y]\)有交集的区间才会产生贡献。
设该点代表区间为\([l,r]\):
- 若\([l,r]\)包含\([x,y]\),则\([x,y]\)的所有子区间都会经过该点。
- 若\([l,r]\)与\([x,y]\)相交,那么只有和\([l,r]\)有交集的子区间有贡献,于是用\([x,y]\)所有的子区间减去和\([l,r]\)没有交的子区间。
- 若\([x,y]\)包含\([l,r]\),这时\([x,y]\)的子区间必须和\([l,r]\)有交集且不能包含\([l,r]\)的父区间,同样做一下减法就行了。
现在问题是,第一、二种情况的\([l,r]\)都是\(O(logn)\)个的,可以暴力做。第三种情况,若\([x,y]\)包含了\([l,r]\),\([x,y]\)肯定也包含了\([l,r]\)的所有子区间,不能暴力下去处理答案。如果我们把\([l,r]\)对\([x,y]\)的贡献写出来,是一个和\(x,y\)有关的二次多项式,于是我们可以维护每个区间各项的系数,统计一下子树内系数之和就行了。
Code
#include
#include
#include
#define lson rt << 1
#define rson rt << 1 | 1
using namespace std;
typedef long long ll;
const int N = 500007;
int n, q, opt;
ll l, r, a, b, lastans;
ll C2(ll n) {
return n * (n + 1) / 2;
}
ll sum[N << 2][3];
void pre(int rt, int l, int r, int fl, int fr) {
if (l != 1 || r != n) {
sum[rt][0] = 2 * l - 2 * fr;
sum[rt][1] = 2 * r - 2 * fl;
sum[rt][2] = -l - 1ll * l * l + r - 1ll * r * r + 2ll * fl * fr - 2 * fl + 2 * fr;
}
if (l == r) return;
int mid = l + r >> 1;
pre(lson, l, mid, l, r), pre(rson, mid + 1, r, l, r);
for (int i = 0; i < 3; ++i) sum[rt][i] += sum[lson][i] + sum[rson][i];
}
void go(int rt, int l, int r, int ql, int qr) {
if (l <= ql && qr <= r) lastans += C2(qr - ql + 1);
if (l < ql && r >= ql && r < qr) lastans += C2(qr - ql + 1) - C2(qr - r);
if (l > ql && l <= qr && r > qr) lastans += C2(qr - ql + 1) - C2(l - ql);
if (ql <= l && r <= qr) {
if (ql != l || r != qr) lastans += C2(qr - ql + 1) - C2(l - ql) - C2(qr - r);
if (l != r) lastans += ((sum[lson][0] + sum[rson][0]) * ql + (sum[lson][1] + sum[rson][1]) * qr + sum[lson][2] + sum[rson][2]) / 2;
return;
}
int mid = l + r >> 1;
if (ql <= mid) go(lson, l, mid, ql, qr);
if (mid + 1 <= qr) go(rson, mid + 1, r, ql, qr);
}
int main() {
freopen("ran.in", "r", stdin);
freopen("ran.out", "w", stdout);
scanf("%d%d%d", &n, &q, &opt);
pre(1, 1, n, 1, n);
while (q--) {
scanf("%lld%lld", &l, &r);
a = (l ^ (lastans * opt)) % n + 1, b = (r ^ (lastans * opt)) % n + 1;
l = min(a, b), r = max(a, b), lastans = 0;
go(1, 1, n, l, r);
printf("%lld\n", lastans);
}
return 0;
}