把查询矩形沿每个对角线方向拆成 3 个三角形做容斥（某条对角线落在矩形内的格子数 = 一个从 1 开始递增的斜坡减去两个斜坡），再用维护“按位置等差加权和”的线段树即可求解。
// Range-update / rectangle-sum solution.
//
// As the query decomposition in solve() implies, every cell (x, y) is worth
// (number of op-1 ranges containing x + y) + (number of op-2 ranges containing
// x - y), and op 3 asks for the total over a rectangle. Cells are presumably
// 1 <= x, y <= n (the trees cover x + y in [1, 2n] and x - y + n in [1, 2n]);
// TODO confirm against the problem statement.
//
// Inclusion-exclusion: along one diagonal direction, the number of rectangle
// cells on the diagonal at offset t from the rectangle's first diagonal is
//   (t + 1) - max(0, t - W) - max(0, t - H),   W = x2 - x1, H = y2 - y1,
// i.e. one ramp minus two ramps ("3 triangles"). A ramp that starts at weight
// 1 on position ql is exactly what query(..., ql, qr).sum returns, so each
// diagonal direction costs three range queries.
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

// Node capacity: a tree over 2n positions needs up to 4 * 2n nodes, so the
// arrays below support 2n <= maxn.
const int maxn = 400005;

// One segment-tree node covering positions [L, R]. With c[i] = number of
// updates covering position i:
//   val  = sum of c[i] over [L, R]
//   sum  = sum of c[i] * (i - L + 1)   (counts weighted by 1-based offset)
//   lazy = "+1 to every c[i]" operations not yet pushed to the children
struct node {
    LL sum, val, lazy;
    node(LL sum = 0, LL val = 0, LL lazy = 0) : sum(sum), val(val), lazy(lazy) {}
} a[maxn << 2], b[maxn << 2];  // a: indexed by x + y, b: indexed by x - y + n

int n, m;

// Pushes node o's pending lazy count down to its two children. Adding k to
// every c[i] of a child of length len raises its val by k * len and its
// offset-weighted sum by k * (1 + 2 + ... + len) = k * len * (len + 1) / 2.
void pushdown(int o, int L, int R, node tree[]) {
    if (!tree[o].lazy) return;
    int mid = (L + R) >> 1;
    int lc = o << 1, rc = o << 1 | 1;
    LL k = tree[o].lazy;
    LL tl = mid - L + 1, tr = R - mid;
    tree[lc].sum += (1 + tl) * tl / 2 * k;
    tree[rc].sum += (1 + tr) * tr / 2 * k;
    tree[lc].val += tl * k;
    tree[rc].val += tr * k;
    tree[lc].lazy += k;
    tree[rc].lazy += k;
    tree[o].lazy = 0;
}

// Recomputes node o from its children. The right child's sum is weighted from
// mid + 1, so it is shifted by (mid - L + 1) per unit of count to re-base it on L.
void pushup(int o, int L, int R, node tree[]) {
    int mid = (L + R) >> 1;
    const node &lhs = tree[o << 1], &rhs = tree[o << 1 | 1];
    tree[o].val = lhs.val + rhs.val;
    tree[o].sum = lhs.sum + rhs.sum + rhs.val * (mid - L + 1);
}

// Resets every node of the subtree rooted at o (covering [L, R]) to zero.
void build(int o, int L, int R, node tree[]) {
    tree[o].sum = tree[o].val = tree[o].lazy = 0;
    if (L >= R) return;  // leaf (or empty range when n == 0): no children
    int mid = (L + R) >> 1;
    build(o << 1, L, mid, tree);
    build(o << 1 | 1, mid + 1, R, tree);
    pushup(o, L, R, tree);
}

// Adds 1 to c[i] for every i in [ql, qr] intersected with [L, R].
// Empty or non-intersecting ranges are a no-op; without this guard they never
// hit the "fully covered" case and recurse forever past the leaves.
void update(int o, int L, int R, node tree[], int ql, int qr) {
    if (ql > qr || qr < L || ql > R) return;
    if (ql <= L && qr >= R) {
        LL len = R - L + 1;
        tree[o].lazy += 1;
        tree[o].val += len;
        tree[o].sum += (1 + len) * len / 2;
        return;
    }
    pushdown(o, L, R, tree);
    int mid = (L + R) >> 1;
    if (ql <= mid) update(o << 1, L, mid, tree, ql, qr);
    if (qr > mid) update(o << 1 | 1, mid + 1, R, tree, ql, qr);
    pushup(o, L, R, tree);
}

// Returns a node for [max(ql, L), min(qr, R)] whose val is the sum of c[i] and
// whose sum is the sum of c[i] * (i - max(ql, L) + 1). Called from the root
// with ql >= 1, that is sum of c[i] * (i - ql + 1) over [ql, qr].
// An empty or non-intersecting range yields an all-zero node; this matters for
// degenerate (single-row / single-column) rectangles in solve().
node query(int o, int L, int R, node tree[], int ql, int qr) {
    if (ql > qr || qr < L || ql > R) return node();
    if (ql <= L && qr >= R) return tree[o];
    pushdown(o, L, R, tree);
    int mid = (L + R) >> 1;
    if (qr <= mid) return query(o << 1, L, mid, tree, ql, qr);
    if (ql > mid) return query(o << 1 | 1, mid + 1, R, tree, ql, qr);
    node res = query(o << 1, L, mid, tree, ql, qr);
    node rhs = query(o << 1 | 1, mid + 1, R, tree, ql, qr);
    res.val += rhs.val;
    // rhs is weighted from mid + 1; re-base it onto the left part's first position.
    res.sum += rhs.sum + (LL)(mid - max(ql, L) + 1) * rhs.val;
    return res;
}

// Ramp-weighted count over diagonals d = x - y in [ql, qr] (tree b, shifted by n).
LL solve2(int ql, int qr) {
    ql += n, qr += n;
    node ans = query(1, 1, 2 * n, b, ql, qr);
    return ans.sum;
}

// Ramp-weighted count over anti-diagonals s = x + y in [ql, qr] (tree a).
LL solve1(int ql, int qr) {
    node ans = query(1, 1, 2 * n, a, ql, qr);
    return ans.sum;
}

// Answers one rectangle query and prints the total.
// Input order is x1 x2 y1 y2 (not x1 y1 x2 y2).
// When x1 == x2 or y1 == y2 some subtracted ramps are empty (ql = qr + 1);
// query() returns zero for them.
void solve() {
    int x1, y1, x2, y2;
    scanf("%d%d%d%d", &x1, &x2, &y1, &y2);
    LL ans = 0;
    // x - y direction: ramp from x1 - y2, minus ramps past x1 - y1 and x2 - y2.
    ans += solve2(x1 - y2, x2 - y1);
    ans -= solve2(x1 - y1 + 1, x2 - y1);
    ans -= solve2(x2 - y2 + 1, x2 - y1);
    // x + y direction: ramp from x1 + y1, minus ramps past x2 + y1 and x1 + y2.
    ans += solve1(x1 + y1, x2 + y2);
    ans -= solve1(x2 + y1 + 1, x2 + y2);
    ans -= solve1(x1 + y2 + 1, x2 + y2);
    printf("%lld\n", ans);
}

// Runs one test case: reads n and m, resets both trees, then processes m operations.
//   op 1 ql qr: every x + y in [ql, qr] gets +1
//   op 2 ql qr: every x - y in [ql, qr] gets +1
//   op 3 x1 x2 y1 y2: rectangle query
void work() {
    scanf("%d%d", &n, &m);
    build(1, 1, 2 * n, a);
    build(1, 1, 2 * n, b);
    while (m--) {
        int op;
        if (scanf("%d", &op) != 1) return;  // truncated input: stop instead of reading garbage
        if (op == 1) {
            int ql, qr;
            scanf("%d%d", &ql, &qr);
            update(1, 1, 2 * n, a, ql, qr);
        } else if (op == 2) {
            int ql, qr;
            scanf("%d%d", &ql, &qr);
            update(1, 1, 2 * n, b, ql + n, qr + n);
        } else if (op == 3) {
            solve();
        }
    }
}

int main() {
    int cases;
    if (scanf("%d", &cases) != 1) return 0;
    for (int i = 1; i <= cases; i++) {
        printf("Case #%d:\n", i);
        work();
    }
    return 0;
}