树状数组之区间最值

树状数组之区间最值

原理

树状数组之区间最值_第1张图片

数学原理:
树状数组之区间最值_第2张图片

建立树状数组

利用上面的性质,在树状数组的尾部插入数据,来建立一个树状数组

void push(int pos){ 
    int i,lb = lowbit(pos); 
    c[pos] = a[pos]; 
    for(i=1;i<lb;i <<=1){ 
        c[pos] = max(c[pos],c[pos-i]); 
     }
}

树的维护

void update(int pos,int v){ 
    int i,lb; 
    c[pos] = a[pos] = v; 
    lb = lowbit(pos); 
    for(i=1;i<lb;i <<=1){ //利用孩子更新自己 
        c[pos] = c[pos] > c[pos-i] ? c[pos] : c[pos-i]; 
    } 
    int pre = c[pos]; 
    pos+=lowbit(pos);//父亲的位置
    /* 更新父亲 */ 
    while(pos <= n){
        if( c[pos] < pre){ //更新的父亲 
            c[pos] = pre; 
            pos +=lowbit(pos); 
        } //没有更新父亲 
        else break; 
     } 
 }

查询最值

   设 query(x,y)query(x,y) 求区间 [x,y] 之间的最值, 已知 c[x] 表示 [x−lowbit(x)+1,x] 之间的最值,那如何求区间 [x,y] 的最值呢?
树状数组之区间最值_第3张图片

我们不难发现:

  • 如果求区间 [1,8] 的最值,就需要点 c[8]
  • 如果求区间 [1,7] 的最值,就需要点 c[7],c[6],c[4]
  • 如果求区间 [2,7] 的最值,就需要点 c[7],c[6],a[4],c[3],a[2]
  • 如果求区间 [2,2] 的最值,就需要点 a[2]如果求区间 [2,8] 的最值,就需要点 a[8],c[7],c[6],a[4],c[3],a[2]

所以,我们发现下面的规律,因为 y−lowbit(y)+1y−lowbit(y)+1 表示 c[y]c[y] 结点所管辖范围的最左边的点若

  • y−lowbit(y)+1>=xy−lowbit(y)+1>=x, 则query(x,y)=max(c[y],query(x,y−lowbit(y)))query(x,y)=max(c[y],query(x,y−lowbit(y)));
  • 若 y−lowbit(y)+1query(x,y)=max(a[y],query(x,y−1))query(x,y)=max(a[y],query(x,y−1));
  • 边界 x>y
int query(int x,int y){
    int res = -1; 
    while(x <= y){ 
        int nx = y - lowbit(y)+1; //最左边的点 
        if(nx >= x ){ 
            res = res < c[y] ? c[y] :res; //判断是否最优
            y = nx-1; // 下一个求解区间 
        } else { // nx < x 
            res = res < a[y] ? a[y] :res; //判断是否最优
            y--;
        } 
    } 
    return res;
}

总结

特点:

  • 每一次在尾部添加一个数值,时间为 log(n)
  • 可以保留原数组的相对应位置不变
  • 如果不进行单点修改,速度会更快

   所以,树状数组求区间最值特别适合那些:一边在尾部添加数据,一边查询的题目

核心代码

const int maxn = 1e6 + 5, maxe = 1e6 + 5; //点与边的数量

int n, m;
int N = maxn;
int a[maxn], c[maxn]; // a是原数组

inline int lowbit(int x) { return x & -x; }
inline int fa(int p) { return p + lowbit(p); }
inline int left(int p) { return p - lowbit(p); }
inline int g(int a, int b) { return a>b ? a : b; }

void update_by_child(int p, int v) { //alias push
    c[p] = a[p] = v;
    int lb = lowbit(p);
    for (int i = 1; i < lb; i <<= 1)
        c[p] = g(c[p], c[p - i]);
}

void update(int p, int v) {
    update_by_child(p, v);
    int t = c[p];
    for (p = fa(p); p <= N; p = fa(p)) {
    if (g(t, c[p])) c[p] = t;
       else break;
    }
}

int query(int l, int r) { // 求区间最值
    int ret = a[l];
    for (; l <= r; ) {
       int next = left(r) + 1;
       if (next >= l) ret = g(ret, c[r]), r = next - 1;
       else            ret = g(ret, a[r]), r--;
    }
    return ret;
}

例题

hdu[1754] I Hate It

思路: 利用树状数组求区间最值

#include
#include
#include
#include

using namespace std;
typedef long long ll;
const int maxn = 1e6 + 5, maxe = 1e6 + 5; //点与边的数量

int n, m;
int N = maxn;
int a[maxn], c[maxn]; // a是原数组

inline int lowbit(int x) { return x & -x; }
inline int fa(int p) { return p + lowbit(p); }
inline int left(int p) { return p - lowbit(p); }
inline int g(int a, int b) { return a>b ? a : b; }

void update_by_child(int p, int v) { //alias push
    c[p] = a[p] = v;
    int lb = lowbit(p);
    for (int i = 1; i < lb; i <<= 1)
        c[p] = g(c[p], c[p - i]);
}

void update(int p, int v) {
    update_by_child(p, v);
    int t = c[p];
    for (p = fa(p); p <= N; p = fa(p)) {
    if (g(t, c[p])) c[p] = t;
       else break;
    }
}

int query(int l, int r) { // 求区间最值
    int ret = a[l];
    for (; l <= r; ) {
       int next = left(r) + 1;
       if (next >= l) ret = g(ret, c[r]), r = next - 1;
       else            ret = g(ret, a[r]), r--;
    }
    return ret;
}

int main() {
    while (1) {
        memset(a, 0, sizeof(a));
        memset(c, 0, sizeof(c));
        if (scanf("%d%d", &n, &m) == EOF) break;
        for (int i = 1; i <= n; ++i) {
            int t;
            scanf("%d", &t);
            update_by_child(i, t); //初始化原数组与树状数组
        }
        char s[10];
        for (int i = 1; i <= m; ++i) {
            scanf("%s", s);
            int x, y;
            if (s[0] == 'Q') {
                scanf("%d%d", &x, &y);
                ll ans = query(x, y);
                printf("%lld\n", ans);
            }
            else {
                scanf("%d%d", &x, &y);
                update(x, y);
            }
        }
    }
    return 0;
}

你可能感兴趣的:(算法,算法,c++)