题目:
给你一个数组A[1::n],初始时每个元素都为零。对数组完成一些操作:
第一种可能,给你两个数p 和x(1<= p <= n),把数组的第p 个元素替换为x,即A[p] =x.
第二种可能,给你两个数L 和R(1 <=L <= R <= n),请给出A[L]; A[L + 1]; : : : ; A[R] 这几个数中去掉一个最大值和一个最小值后剩下的数的和是多少.
题解:
单点修改、区间查询的线段树可以在O(log n) 的时间内维护区间的和、最大值、最小值.
此题用三个线段树维护各个区间的和、最大值、最小值,对每个查询输出对应区间的sum . max . min 即可.
#include
#include
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 5;
struct node {
int left, right;
int mx, mn;
ll sum;
}rt[maxn * 4];
int tl(int x) { return x * 2; }
int tr(int x) { return x * 2 + 1; }
void pushup(int x) {
rt[x].sum = rt[tl(x)].sum + rt[tr(x)].sum;
rt[x].mx = max(rt[tl(x)].mx, rt[tr(x)].mx);
rt[x].mn = min(rt[tl(x)].mn, rt[tr(x)].mn);
}
void build(int x,int L,int R) {
rt[x].left = L; rt[x].right = R;
if (L == R) {
rt[x].sum = rt[x].mn = rt[x].mx = 0;
return;
}
int mid = L + R >> 1;
build(tl(x), L, mid); build(tr(x), mid + 1, R);
pushup(x);
}
void upd(int x, int pos, ll val) {
if (rt[x].left == rt[x].right) {
rt[x].sum = rt[x].mn = rt[x].mx = val;
return;
}
int mid = rt[x].left + rt[x].right >> 1;
if (pos <= mid)upd(tl(x), pos, val);
else upd(tr(x), pos, val);
pushup(x);
}
ll qury(int x, int L, int R) {
if (L == rt[x].left&&R == rt[x].right) {
return rt[x].sum;
}
int mid = rt[x].left + rt[x].right >> 1;
if (R <= mid) return qury(tl(x), L, R);
else if (L > mid) return qury(tr(x), L, R);
else return qury(tl(x), L, mid) + qury(tr(x), mid + 1, R);
}
ll fmin(int x, int L, int R) {
if (L == rt[x].left&&R == rt[x].right) {
return rt[x].mn;
}
int mid = rt[x].left + rt[x].right >> 1;
ll ret = 0x3fffffffffffff;
if (R <= mid) return min(ret, fmin(tl(x), L, R));
else if (L > mid) return min(ret, fmin(tr(x), L, R));
else return min(fmin(tl(x), L, mid), fmin(tr(x), mid + 1, R));
}
ll fmax(int x, int L, int R) {
if (L == rt[x].left&&R == rt[x].right) {
return rt[x].mx;
}
int mid = rt[x].left + rt[x].right >> 1;
ll ret = -0x3ffffffffffffff;
if (R <= mid) return max(ret, fmax(tl(x), L, R));
else if (L > mid) return max(ret, fmax(tr(x), L, R));
else return max(fmax(tl(x), L, mid), fmax(tr(x), mid + 1, R));
}
int main() {
int n, m;
scanf("%d%d", &n, &m);
build(1, 1, n);
while (m--) {
int o, x, y;
scanf("%d%d%d", &o, &x, &y);
if (o == 0) upd(1, x, y);
else printf("%lld\n", qury(1, x, y) - fmax(1, x, y) - fmin(1, x, y));
}
return 0;
}