树状数组专题入门——POJ 2352,1556,2155,3321,1990,2309,1195,2481,2299,3067

————欢迎探讨交流,如果什么不足和错误请指出。

从白书上学习到前缀和,到写了这么些题目,目前看来树状数组最主要用途就是低复杂度的点,区间修改和查询,lower_bit真的是很神奇的东西,树状数组就是靠着这个东西加减最低位1来维护执行add操作和sum操作,通过BST POJ - 2309 这道题目可以加深对lower_bit和树状数组原理的了解。

这张图并没有完全展示了树状数组的结点关系,只展示了部分(例如12应该连到16上,10连到8上),在树状数组中,每个结点都负责统计了他的左子树的数值,我们修改点12时通过add函数也会修改到点16(因为12在16的左子树里),然而不会修改到点8,这就保证了整个数据储存的正确性了,而如果我们统计1到10的数值和,10只负责到他的左子树9,所以就得通过减去lower_bit来找上一个负责点,这里找到的是8,而点8刚好负责了1到8的全部,从lower_bit上看也是,8-lower_bit(8)=0,所以到点8就统计完了1到10区间了,这也就是树状数组的高效性所在,比遍历效率不知道高到哪里去了。所以简而言之,add(node x)的过程就是不断的向右上寻找把所有负责到这个点x的结点的值都更新,而sum的过程就是不断向左上寻找结点使得这些结点的负责区域合起来刚好就是你要查找的区间。


接下来是一些基础的入门题目。
树状数组一个很典型的应用题目就是求数列逆序数之和,也就是相邻交换变成递增的最少次数,很多题目都是有很相似的思路,一边求解一边维护树状数组,求逆序数之和就是从后往前维护,每维护到一个数就找寻比他小的数有几个,然后在把当前数add进树状数组里。

1.Stars POJ - 2352

数据的给出是按y递增的,而我们要找的就是比当前点x小并且y小的点的数量,处理一个点时,之前的点y都不比他大,只需要用树状数组维护并查找x比他小的点的数量就行了。

#include
#include
#include
using namespace std;
int level[50000];
int cc[50000];
int lowbit(int x){ return x&(-x); }
int sum(int x){
    int rrr = 0;
    while (x > 0){
        rrr += cc[x]; x -= lowbit(x);
    }
    return rrr;
}
void add(int x){
    while (x < 32005){
        cc[x]++;
        x += lowbit(x);
    }
}
int main(){
    int x,y,n;
    scanf("%d", &n);
        //memset(lev, 0, sizeof(lev));
        //memset(c, 0, sizeof(c));
        for (int i = 0; i < n; i++){
            scanf("%d%d", &x, &y);
            level[sum(x+1)]++;
            add(x+1);
        }
        for (int i =0; i printf("%d\n", level[i]);
        }

    return 0;
}

Color the ball HDU - 1556

区间上色,有的人可能一开始不大明白树状数组怎么进行的区间操作,其实因为树状数组本身就是统计区域和的,所以你在x加1之后,x之后的点进行sum(n)操作(统计1到n)都会统计到这个1,因为sum(n)函数的本质是统计1到n区间的总和,所以如果我们在x add 1, y(>x) add -1,那么y之后的点执行sum函数,既会访问到-1又会访问到1就抵消掉了,而x到y这段区间的点进行sum不会访问到-1,就得到了一次染色的统计。这类题目里面我们的行为和使用前缀和有一定的相似度,在区间首加上value,尾部+1减去value。统计时便保证了之后该区间被加上了value。

#include
#include
#include
#include
using namespace std;
int cc[100005];
int n;
int nn[100005];
int lowbit(int x){ return x&(-x); }
int sum(int x){
    int rrr = 0;
    while (x > 0){
        rrr += cc[x]; x -= lowbit(x);
    }
    return rrr;
}
void add(int x,int d){
    while (x 1){
        cc[x]+=d;
        x += lowbit(x);
    }
}

int main(){
    int x, y;
    while (scanf("%d", &n), n){
        memset(cc, 0, sizeof(cc));
        for (int i = 0; i < n; i++){
            scanf("%d%d", &x, &y);
            add(x, 1);
            add(y + 1, -1);
        }
        for (int i = 1; i <= n; i++){
            printf("%d", sum(i));
            if (i < n)printf(" ");
        }
        printf("\n");
    }
    return 0;
}

Matrix POJ - 2155

也是区间操作点查询,从一维拓展到了二维,行为也和处理二维前缀和很相似,add函数和sum函数都要相应的改变。add里执行的操作也从加法变成了异或。

#include
#include
#include
#include
using namespace std;
const int maxn = 1024;
int n,k;
bool map[maxn+1][maxn+1];
int lowbit(int x){ return x&(-x); }
int sum(int x,int y){
    bool rrr = 0;
    int yy = y;
    while (x > 0){
        y = yy;
        while (y > 0){
            rrr ^= map[x][y];
            y -= lowbit(y);
        }
        x -= lowbit(x);
    }
    return rrr;
}
void add(int x, int y){
    int yy = y;
    while (x <=maxn){
        y = yy;
        while (y <=maxn){
            map[x][y] ^= 1;
            y += lowbit(y);
        }
    x += lowbit(x);
    }
}
char chr;
int main(){
    int x1, y1, x2, y2;
    int t;
    scanf("%d", &t);
    while (t--){
        scanf("%d%d", &n, &k);
        memset(map, 0, sizeof(map));
        for (int i = 0; i < k; i++){
            cin >> chr;
            if (chr == 'C'){
                scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
                add(x1, y1);
                add(x1, y2 + 1);
                add(x2 + 1, y1);
                add(x2 + 1, y2+1);
            }
            else if (chr == 'Q'){
                scanf("%d%d", &x1, &y1);
                printf("%d\n",sum(x1, y1));
            }
        }

        if (t != 0)printf("\n");
    }
    return 0;
}

Mobile phones POJ - 1195

也是二维的题目,不过变成了点操作,区间查询,就比较不适合在add的时候用区间首加区间尾减的方法了,只需要在查询时候做一点修改就行了, +sum(z+1, d+1)是查询(0,0)到(z,d)的总和,-sum(z+1, y)是减去(0,0)到(z,y-1)的总和,以此类推,加加减减之后刚好就剩下你要的区间了。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
#define N 20005

int n;
int c[1030][1030];
int lowbit(int x){ return x&-x; }

int sum(int x,int y){
    int ans = 0;
    for (int i = x; i > 0; i -= lowbit(i)){
        for (int j = y; j > 0; j -= lowbit(j)){
            ans += c[i][j];
        }
    }
    return ans;
}

void add(int x,int y, int val){
    for (int i = x; i < 1030; i += lowbit(i)){
        for (int j = y; j < 1030; j += lowbit(j)){
            c[i][j] += val;
        }
    }
}


int st;
int main(){
    int x, y, z, d;
    scanf("%d%d", &st, &n);
    while (scanf("%d", &st), st != 3){
        if (st == 1){
            scanf("%d%d%d", &x, &y, &z);
            add(x+1, y+1, z);
        }
        else{
            scanf("%d%d%d%d", &x, &y, &z, &d);
            int ans = 0;
            ans += sum(z+1, d+1);
            ans -= sum(z+1, y);
            ans -= sum(x, d+1);
            ans += sum(x, y);
            printf("%d\n", ans);
        }
    }
    return 0;
}

Apple Tree POJ - 3321

这题的难点就是怎么把树的结构转换成适合树状数组方便统计的线性结构,做法就是通过dfs深搜树来对结点编号,对于每个结点x,lef[x]表示这个结点在线性表的位置,rig[x]表示这个结点的子节点在线性表最右端的位置,lef[x]到rig[x]这个区间就是结点x子树的统计结果了。

#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define N 100005
int n, m;
vector< vector<int> > e(N);
int tot;
int lef[N], rig[N], node[N],s[N];
char cc[10];
int lower_bit(int n){ return n&-n; }

void dfs(int cnt){
    lef[cnt] = tot;
    for (int i = 0; i < e[cnt].size(); i++){
        tot++;
        dfs(e[cnt][i]);
    }
    rig[cnt] = tot;
}

void add(int x, int val){
    while (x <= n){
        node[x] += val;
        x += lower_bit(x);
    }
}

int query(int x){
    int ans = 0;
    while (x>0){
        ans += node[x];
        x -= lower_bit(x);
    }
    return ans;
}

int main(){
    int x, y;
    while (~scanf("%d", &n)){
        for (int i = 0; i memset(node, 0, sizeof(node));
        memset(lef, 0, sizeof(lef));
        memset(rig, 0, sizeof(rig));
        memset(s, 0, sizeof(s));
        for (int i = 1; i < n; i++){
            scanf("%d%d", &x, &y);
            e[x].push_back(y);
        //  e[y].push_back(x);
        }
        tot = 1;
        dfs(1);
        for (int i = 1; i <= n; i++){
            add(i, 1);
            s[i] = 1;
        }
        scanf("%d", &m);
        for (int i = 0; i < m; i++){
            scanf("%s%d", cc, &x);
            if (cc[0] == 'Q'){
                printf("%d\n", query(rig[x]) - query(lef[x] - 1));
            }
            else{
                if (s[x] == 0){
                    add(lef[x], 1);
                }
                else{
                    add(lef[x], -1);
                }
                s[x] = !s[x];
            }
        }
    }
    return 0;
}

MooFest POJ - 1990

两头牛之间交流所需要的vol是它们距离*两头牛耳聋程度的最大值,最简单暴力的方法就是枚举两头牛来计算,复杂度自然不满足数据量和时间的要求了,通过对耳聋程度排序,然后维护两个树状数组,一个用来统计当前小于位置x的牛的数量和大于x的数量,另一个用来统计当前小于x的位置总和和大于x的位置总和,就可以比较容易地求出答案了。

#include 
#include 
#include 
#include 
using namespace std;
#define N 20005
struct cow{
    int pos;
    int hear;
}p[N];
int n;
int c[N+1];
int lr[N+1];
int lowbit(int x){ return x&-x; }

long long sum(int *node,int x){
    long long ans = 0;
    while (x>0){
        ans += node[x];
        x -= lowbit(x);
    }
    return ans;
}

void add(int *node,int x, int val){
    while (x <= N){
        node[x] += val;
        x += lowbit(x);
    }
}

bool cmp(cow a, cow b){
    if (a.hear == b.hear)return a.pos < b.pos;
    return a.hear < b.hear;
}

int main(){
    scanf("%d", &n);
    for (int i = 0; i < n; i++){
        scanf("%d%d", &p[i].hear, &p[i].pos);
    }
    sort(p, p + n, cmp);
    add(c,p[0].pos,p[0].pos);
    add(lr, p[0].pos, 1);
    long long ans = 0;
    for (int i = 1; i < n; i++){
        long long lef = sum(lr, p[i].pos), rig = sum(lr, N) - sum(lr, p[i].pos);
        ans += p[i].hear*(lef*p[i].pos - sum(c, p[i].pos) + (sum(c,N)-sum(c,p[i].pos)-rig*p[i].pos));
        add(c, p[i].pos, p[i].pos);
        add(lr, p[i].pos, 1);
    }
    printf("%lld", ans);
    return 0;
}

Cows POJ - 2481

通过favourite clover range来计算牛与牛之间的强壮比较,要求统计每头牛有多少头牛比他强壮,思想跟逆序和 的比较相似,按左区间排序或者右区间排序,然后顺序维护和统计,注意对区间刚好重叠的牛的处理。

#include 
#include 
#include 
#include 
using namespace std;
#define N 100005

int n;
int c[N + 1];
int ans[N + 1];
struct cow{
    int l, r, id;
}p[100005];

int lowbit(int x){ return x&-x; }

int sum(int x){
    int ans = 0;
    while (x>0){
        ans += c[x];
        x -= lowbit(x);
    }
    return ans;
}

void add(int x, int val){
    while (x <= N){
        c[x] += val;
        x += lowbit(x);
    }
}

bool cmp(cow a, cow b){
    if (a.l == b.l)return a.r > b.r;
    return a.l < b.l;
}

int main(){
    while (scanf("%d", &n), n){
        memset(c, 0, sizeof(c));
        for (int i = 0; i < n; i++){
            scanf("%d%d", &p[i].l, &p[i].r);
            p[i].l++, p[i].r++;
            p[i].id = i;
        }
        sort(p, p + n, cmp);
        int cnt = 0;
        for (int i = 0; i < n; i++){
            int ll = p[i].l, rr = p[i].r;
            if (i>0 && p[i - 1].l == ll&&p[i - 1].r == rr){ cnt++; }
            else{ cnt = 0; }
            ans[p[i].id] = sum(N) - sum(rr - 1)-cnt;
            add(rr, 1);
        }
        for (int i = 0; i < n; i++){
            printf("%d", ans[i]);
            if (i < n - 1)printf(" ");
        }
        printf("\n");
    }
    return 0;
}

Ultra-QuickSort POJ - 2299

就是求逆序数的= =,不过要离散化,因为数据的大小上限过大,树状数组记录的数据大小上限就是数组能开的最大值(我这么认为),n又比较小,所以离散化一下就能把数据都处理在限定范围内了,注意会有相同大小的数字以及注意溢出。

#include 
#include 
#include 
#include 
using namespace std;
#define N 5000005

int n;
int c[N + 1];
int nu[N];
struct nn{
    int num, id;
}num[N];
int lowbit(int x){ return x&-x; }

long long sum(int x){
    long long ans = 0;
    while (x>0){
        ans += c[x];
        x -= lowbit(x);
    }
    return ans;
}

void add(int x, int val){
    while (x <= N){
        c[x] += val;
        x += lowbit(x);
    }
}

bool cmp(nn a, nn b){
    return a.num < b.num;
}

int main(){
    while (scanf("%d", &n), n){
        long long ans = 0;
        memset(c, 0, sizeof(c));
        for (int i = 0; i < n; i++){
            scanf("%d", &num[i].num);
            num[i].id = i;
        }
        sort(num, num + n, cmp);
        int tot = 1;
        for (int i = 0; i < n; i++){
            if (i>0 && num[i].num == num[i - 1].num){ nu[num[i].id] = tot; }
            else nu[num[i].id] = tot++;
        }
        for (int i = n - 1; i >= 0; i--){
            ans += sum(nu[i]-1);
            add(nu[i], 1);
        }
        printf("%lld\n", ans);
    }
    return 0;
}

Japan POJ - 3067

东海岸与西海岸的城市修桥,问桥之前交叉的次数,对于两座桥a,b,如果a的左城市编号小于b的左城市编号,右城市标号又大于b的右城市编号,就会产生一次交叉,直接思路就是顺序从下到上枚举左边城市的桥,维护并查询比当前桥右城市编号小的就行(因为当前树状数组里记录的桥的左城市编号都比当前桥大)。注意的是同一城市发出的桥不会相交,所以处理是一个城市处理完再把整个城市的桥都add到树状数组里。

#include 
#include 
#include 
#include 
#include 
using namespace std;
#define N 1005
vector<int> e[1005];
int n,m,k;
int c[N + 1];
struct nn{
    int num, id;
}num[N];
int lowbit(int x){ return x&-x; }

long long sum(int x){
    long long ans = 0;
    while (x>0){
        ans += c[x];
        x -= lowbit(x);
    }
    return ans;
}

void add(int x, int val){
    while (x <= N){
        c[x] += val;
        x += lowbit(x);
    }
}

bool cmp(nn a, nn b){
    return a.num < b.num;
}

int main(){
    int x, y;
    int t;
    scanf("%d", &t);
    for (int cas = 1; cas <= t; cas++){
        memset(c, 0, sizeof(c));
        for (int i = 1; i <= n; i++)e[i].clear();
        scanf("%d%d%d", &n, &m, &k);
        for (int i = 0; i < k; i++){
            scanf("%d%d", &x, &y);
            e[x].push_back(y);
        }
        long long ans = 0;
        for (int i = n; i >=1; i--){
            for (int j = 0; j < e[i].size(); j++){
                ans += sum(e[i][j] - 1);
            }
            for (int j = 0; j < e[i].size(); j++){
                add(e[i][j], 1);
            }
        }
        printf("Test case %d: %lld\n", cas, ans);
    }
    return 0;
}

你可能感兴趣的:(数据结构)