题目链接
题意:对一个序列支持两类区间修改——一类是在区间上叠加一个等差数列(A 操作叠加递增的 1, 2, 3, …,B 操作叠加递减的 …, 3, 2, 1),另一类是把区间内的每个元素都设为某个值(C 操作);同时需要回答区间和查询(S 操作)。
思路:关键在于等差数列的维护。线段树的每个结点记录一个首项和一个公差作为懒标记;由于两个等差数列逐项相加仍然是等差数列,标记可以直接叠加,利用这个性质就可以进行维护了。注意 set 操作会覆盖掉此前的等差数列标记,因此下传时要先处理 set 标记,再处理等差数列标记。
代码:
#include <cstdio>
#include <cstring>

// Children of node x in a 0-rooted, heap-ordered segment tree.
#define lson(x) (((x) << 1) + 1)
#define rson(x) (((x) << 1) + 2)

typedef long long ll;

// Largest position handled by the tree; main() builds it over [1, MAXPOS].
const int MAXPOS = 250000;
const int N = MAXPOS + 5;

int n;  // number of operations in the current test case

struct Node {
    ll l, r;   // inclusive range covered by this node
    ll a1, d;  // pending arithmetic progression: term added at l, and common difference
    ll c;      // pending assignment value; meaningful only while setc != 0
    ll val;    // sum over [l, r], already including this node's own pending tags
    int setc;  // non-zero while an assignment still has to be pushed to the children
} node[N * 4];

// Builds the subtree rooted at x over [l, r] with every value and tag cleared.
void build(ll l, ll r, int x = 0) {
    node[x].l = l;
    node[x].r = r;
    node[x].a1 = node[x].d = node[x].c = node[x].val = 0;
    node[x].setc = 0;
    if (l == r) return;
    ll mid = (l + r) / 2;
    build(l, mid, lson(x));
    build(mid + 1, r, rson(x));
}

// Recomputes the sum of x from its two children.
void pushup(int x) {
    node[x].val = node[lson(x)].val + node[rson(x)].val;
}

// Moves the pending tags of x down to its children.
// The assignment tag is applied first and the progression second, because a
// progression recorded on a node was always added after that node's assignment
// (C() clears a1/d whenever it sets a node).
void pushdown(int x) {
    // A leaf has no children; touching them would write outside the built tree.
    if (node[x].l == node[x].r) return;

    const int lc = lson(x);
    const int rc = rson(x);

    if (node[x].setc) {
        node[lc].c = node[rc].c = node[x].c;
        node[lc].val = (node[lc].r - node[lc].l + 1) * node[x].c;
        node[rc].val = (node[rc].r - node[rc].l + 1) * node[x].c;
        node[lc].setc = node[rc].setc = 1;
        // The assignment overrides whatever progression the children still held.
        node[lc].a1 = node[lc].d = 0;
        node[rc].a1 = node[rc].d = 0;
        node[x].setc = 0;
    }

    const ll l = node[x].l;
    const ll r = node[x].r;
    const ll mid = (l + r) / 2;
    const ll len1 = mid - l + 1;  // size of the left child
    const ll len2 = r - mid;      // size of the right child
    // Term of the parent's progression that lands on position mid + 1.
    const ll amid = node[x].a1 + node[x].d * len1;

    // The sum of two arithmetic progressions is again one, so tags simply add up.
    node[lc].a1 += node[x].a1;
    node[lc].d += node[x].d;
    node[lc].val += node[x].a1 * len1 + len1 * (len1 - 1) / 2 * node[x].d;

    node[rc].a1 += amid;
    node[rc].d += node[x].d;
    node[rc].val += amid * len2 + len2 * (len2 - 1) / 2 * node[x].d;

    node[x].a1 = node[x].d = 0;
}

// Adds an arithmetic progression with common difference d over [l, r]:
// with d == 1 position l receives 1 and position r receives r - l + 1;
// with d == -1 position l receives r - l + 1 and position r receives 1.
void A(ll l, ll r, ll d, int x = 0) {
    // Nothing to do for a disjoint node; this also covers l > r and ranges
    // outside the tree, which would otherwise recurse past the leaves.
    if (r < node[x].l || l > node[x].r) return;

    if (node[x].l >= l && node[x].r <= r) {
        // Term of the progression that falls on this node's first position.
        const ll st = (d == -1) ? (r - node[x].l + 1) : (node[x].l - l + 1);
        const ll len = node[x].r - node[x].l + 1;
        node[x].a1 += st;
        node[x].d += d;
        node[x].val += st * len + len * (len - 1) / 2 * d;
        return;
    }

    pushdown(x);
    const ll mid = (node[x].l + node[x].r) / 2;
    if (l <= mid) A(l, r, d, lson(x));
    if (r > mid) A(l, r, d, rson(x));
    pushup(x);
}

// Assigns the value c to every position in [l, r].
void C(ll l, ll r, ll c, int x = 0) {
    if (r < node[x].l || l > node[x].r) return;  // disjoint or empty range

    if (node[x].l >= l && node[x].r <= r) {
        node[x].setc = 1;
        node[x].c = c;
        node[x].val = (node[x].r - node[x].l + 1) * c;
        // An assignment discards any progression still pending on this node.
        node[x].a1 = node[x].d = 0;
        return;
    }

    pushdown(x);
    const ll mid = (node[x].l + node[x].r) / 2;
    if (l <= mid) C(l, r, c, lson(x));
    if (r > mid) C(l, r, c, rson(x));
    pushup(x);
}

// Returns the sum of the positions in [l, r]; 0 when the range misses the node.
ll S(ll l, ll r, int x = 0) {
    if (r < node[x].l || l > node[x].r) return 0;  // disjoint or empty range

    if (node[x].l >= l && node[x].r <= r) return node[x].val;

    pushdown(x);
    const ll mid = (node[x].l + node[x].r) / 2;
    ll ans = 0;
    if (l <= mid) ans += S(l, r, lson(x));
    if (r > mid) ans += S(l, r, rson(x));
    // No pushup needed: pushdown keeps node[x].val equal to the children's sum.
    return ans;
}

// Reads test cases until input ends: an operation count followed by that many
// operations "A l r", "B l r", "C l r v" or "S l r". Prints one line per S.
int main() {
    // Compare against 1 rather than testing for EOF only: a matching failure
    // returns 0 and would otherwise loop forever on the same bad token.
    while (scanf("%d", &n) == 1) {
        build(1, MAXPOS);
        while (n-- > 0) {
            char Q[2];
            ll a, b, c;
            // %1s keeps the command inside the two-byte buffer.
            if (scanf("%1s%lld%lld", Q, &a, &b) != 3) return 0;
            switch (Q[0]) {
            case 'A':
                A(a, b, 1);
                break;
            case 'B':
                A(a, b, -1);
                break;
            case 'C':
                if (scanf("%lld", &c) != 1) return 0;
                C(a, b, c);
                break;
            case 'S':
                printf("%lld\n", S(a, b));
                break;
            default:
                break;  // unknown command: operands already consumed, skip it
            }
        }
    }
    return 0;
}