有一种树叫做线段树,有一种数组叫做树状数组

近日受到微软编程之美大赛第二题和hdu一些题目变态般的大数据的刺激,而且老是听到群里的一些大神讲什么线段树,树状数组,分桶法呀等等一系列不明觉厉的东西,花了几天好好看了下线段树和树状数组,下面我来分享一些,我的心得和感悟,如有不足之处欢迎大神们前来狂喷。

微软编程之美初赛第一场树题解http://blog.csdn.net/asdfghjkl1993/article/details/24306921

线段树和树状数组都是一种擅长处理区间的数据结构。它们间最大的区别之一就是线段树是一颗完美二叉树,而树状数组(BIT)相当于是线段树中每个节点的右儿子去掉。

如图:

线段树

 

 有一种树叫做线段树,有一种数组叫做树状数组

树状数组:

有一种树叫做线段树,有一种数组叫做树状数组

 

 

树状数组一般适用于三类问题:

1,修改一个点求一个区间

2,修改一个区间求一个点

3,求逆序列对

 

而用树状数组能够解决的问题,用线段树肯定能够解决,反之则不一定。但是树状数组有一个明显的好处就是较为节省空间,实现要比线段树要容易得多,而且在处理某些问题的时候使用树状数组效率反而会高得多。 昨天看到某位大牛在博客上也留下了这样一句话,线段树擅长处理横向区间的问题,树状数组擅长处理纵向区间的问题,可能由于水平有限,暂时还木有体会到这一点。。。。忧伤。。。

 

下面我们来看两道比较基础的线段树模板题

 

首先是点修改的:

 

一次修改一个点,然后查询最大值还有和:

 

 1 void update(int u,int v,int o,int l,int r)

 2 

 3 {

 4 

 5 int m=(l+r)/2;

 6 

 7 if(l==r)

 8 

 9 {

10 

11 maxv[o]=v;

12 

13 sum[o]=v;

14 

15 }

16 

17 else

18 

19 {

20 

21 if(u<=m)

22 

23 update(u,v,o*2,l,m);

24 

25 else

26 

27 update(u,v,o*2+1,m+1,r);

28 

29 maxv[o]=max(maxv[o*2],maxv[o*2+1]);

30 

31 sum[o]=sum[o*2]+sum[o*2+1];

32 

33 }

34 

35 }

36 

37 int query_sum(int ql,int qr,int o,int l,int r)

38 

39 {

40 

41 int m=(l+r)/2;

42 

43 if(ql<=l&&r<=qr)

44 

45 return sum[o];

46 

47 if(ql<=m)

48 

49 return query_sum(ql,qr,o*2,l,m);

50 

51 if(m<qr)

52 

53 return query_sum(ql,qr,o*2+1,m+1,r);

54 

55 }

56 

57 int query_max(int ql,int qr,int o,int l,int r)

58 

59 {

60 

61 int m=(l+r)/2,ans=-1;

62 

63 if(ql<=l&&r<=qr)

64 

65 return maxv[o];

66 

67 if(ql<=m)

68 

69 return max(ans,query_max(ql,qr,o*2,l,m));

70 

71 if(m<qr)

72 

73 return max(ans,query_max(ql,qr,o*2+1,m+1,r));

74 

75 }

 


 

 

然后是区间修改的:

 

Uva11992这道题是刘汝佳厚白书中的例题

题目链接:http://uva.onlinejudge.org/index.php?option=com_onlinejudge&Itemid=8&page=show_problem&problem=3143

大意为对一个矩阵进行操作,选择其中子矩阵(x1,y1,x2,y2)可以让它每个元素增加v

也可以让它每个元素等于v,也可以查询这个子矩阵的元素和,最小值,最大值。

解决方法当然是线段树,不过对于这棵线段树的update,对于set操作要请除节点上的

Addv标记,但对于add操作不清楚setv标记,在maintain函数中先考虑setv再考虑addv

而在query中要综合考虑setv和addv.

  1 #include<iostream>

  2 

  3 #include<cstdio>

  4 

  5 #include<cstring>

  6 

  7 #include<algorithm>

  8 

  9 using namespace std;

 10 

 11  

 12 

 13 const int maxnode = 1<<17;

 14 

 15  

 16 

 17 int _sum, _min, _max, op, x1, x2, y1, y2, x, v;

 18 

 19  

 20 

 21 class IntervalTree {

 22 

 23   int sumv[maxnode], minv[maxnode], maxv[maxnode], setv[maxnode], addv[maxnode];

 24 

 25  

 26 

 27   // 维护节点o

 28 

 29   void maintain(int o, int L, int R) {

 30 

 31     int lc = o*2, rc = o*2+1;

 32 

 33     if(R > L) {

 34 

 35       sumv[o] = sumv[lc] + sumv[rc];

 36 

 37       minv[o] = min(minv[lc], minv[rc]);

 38 

 39       maxv[o] = max(maxv[lc], maxv[rc]);

 40 

 41     }

 42 

 43     if(setv[o] >= 0) { minv[o] = maxv[o] = setv[o]; sumv[o] = setv[o] * (R-L+1); }

 44 

 45     if(addv[o]) { minv[o] += addv[o]; maxv[o] += addv[o]; sumv[o] += addv[o] * (R-L+1); }

 46 

 47   }

 48 

 49  

 50 

 51   //标记传递

 52 

 53   void pushdown(int o) {

 54 

 55     int lc = o*2, rc = o*2+1;

 56 

 57     if(setv[o] >= 0) {

 58 

 59       setv[lc] = setv[rc] = setv[o];

 60 

 61       addv[lc] = addv[rc] = 0;

 62 

 63       setv[o] = -1; // 清楚标记

 64 

 65     }

 66 

 67     if(addv[o]) {

 68 

 69       addv[lc] += addv[o];

 70 

 71       addv[rc] += addv[o];

 72 

 73       addv[o] = 0; // Çå³ý±¾½áµã±ê¼Ç

 74 

 75     }

 76 

 77   }

 78 

 79  

 80 

 81   void update(int o, int L, int R) {

 82 

 83     int lc = o*2, rc = o*2+1;

 84 

 85     if(y1 <= L && y2 >= R) { // 在区间内

 86 

 87       if(op == 1) addv[o] += v;

 88 

 89       else { setv[o] = v; addv[o] = 0; }

 90 

 91     } else {

 92 

 93       pushdown(o);

 94 

 95       int M = L + (R-L)/2;

 96 

 97       if(y1 <= M) update(lc, L, M); else maintain(lc, L, M);

 98 

 99       if(y2 > M) update(rc, M+1, R); else maintain(rc, M+1, R);

100 

101     }

102 

103     maintain(o, L, R);

104 

105   }

106 

107  

108 

109   void query(int o, int L, int R, int add) {

110 

111     if(setv[o] >= 0) {

112 

113       int v = setv[o] + add + addv[o];

114 

115       _sum += v * (min(R,y2)-max(L,y1)+1);

116 

117       _min = min(_min, v);

118 

119       _max = max(_max, v);

120 

121     } else if(y1 <= L && y2 >= R) {

122 

123       _sum += sumv[o] + add * (R-L+1);

124 

125       _min = min(_min, minv[o] + add);

126 

127       _max = max(_max, maxv[o] + add);

128 

129     } else {

130 

131       int M = L + (R-L)/2;

132 

133       if(y1 <= M) query(o*2, L, M, add + addv[o]);

134 

135       if(y2 > M) query(o*2+1, M+1, R, add + addv[o]);

136 

137     }

138 

139   }

140 

141 };

142 

143  

144 

145 const int maxr = 20 + 5;

146 

147 const int INF = 1000000000;

148 

149  

150 

151 IntervalTree tree[maxr];

152 

153  

154 

155 int main() {

156 

157   int r, c, m;

158 

159   while(scanf("%d%d%d", &r, &c, &m) == 3) {

160 

161     memset(tree, 0, sizeof(tree));

162 

163     for(x = 1; x <= r; x++) {

164 

165       memset(tree[x].setv, -1, sizeof(tree[x].setv));

166 

167       tree[x].setv[1] = 0;

168 

169     }

170 

171     while(m--) {

172 

173       scanf("%d%d%d%d%d", &op, &x1, &y1, &x2, &y2);

174 

175       if(op < 3) {

176 

177         scanf("%d", &v);

178 

179         for(x = x1; x <= x2; x++) tree[x].update(1, 1, c);

180 

181       } else {

182 

183         _sum = 0; _min = INF; _max = -INF;

184 

185         for(x = x1; x <= x2; x++) tree[x].query(1, 1, c, 0);

186 

187         printf("%d %d %d\n", _sum, _min, _max);

188 

189       }

190 

191     }

192 

193   }

194 

195   return 0;

196 

197 }

 


 

再来看看树状数组的

 

先来个改点求区间的

看看hdu1161

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=1166

题目大意:给n个初始数据构建一棵树状数组,然后进行查询求和等一些列操作。

标准模板题,不解释。

 

  1 #include<iostream>

  2 

  3 #include<algorithm>

  4 

  5 #include<cstring>

  6 

  7 #include<cstdio>

  8 

  9 #include<cmath>

 10 

 11 using namespace std;

 12 

 13 const int MAX=50005;

 14 

 15 int N;

 16 

 17 class BIT

 18 

 19 {

 20 

 21 private:

 22 

 23     int bit[MAX];

 24 

 25     int lowbit(int t)

 26 

 27     {

 28 

 29         return t&-t;

 30 

 31     }

 32 

 33 public:

 34 

 35     BIT()

 36 

 37     {

 38 

 39         memset(bit,0,sizeof(bit));

 40 

 41     }

 42 

 43     int sum(int i)

 44 

 45     {

 46 

 47         int s=0;

 48 

 49         while(i>0)

 50 

 51         {

 52 

 53             s+=bit[i];

 54 

 55             i-=lowbit(i);

 56 

 57         }

 58 

 59         return s;

 60 

 61     }

 62 

 63     void add(int i,int v)

 64 

 65     {

 66 

 67         while(i<=N)

 68 

 69         {

 70 

 71             bit[i]+=v;

 72 

 73             i+=lowbit(i);

 74 

 75         }

 76 

 77     }

 78 

 79 };

 80 

 81 int main()

 82 

 83 {

 84 

 85     int T;

 86 

 87     while(cin>>T)

 88 

 89     {

 90 

 91         for(int t=1;t<=T;t++)

 92 

 93         {

 94 

 95             printf("Case %d:\n",t);

 96 

 97             cin>>N;

 98 

 99             BIT tree;

100 

101             for(int i=1;i<=N;i++)

102 

103             {

104 

105                 int x;

106 

107                 cin>>x;

108 

109                 tree.add(i,x);

110 

111             }

112 

113             char ord[15];

114 

115             while(scanf("%s",ord)&&strcmp(ord,"End"))

116 

117             {

118 

119                 int a,b;

120 

121                 scanf("%d%d",&a,&b);

122 

123                 switch(ord[0])

124 

125                 {

126 

127                 case 'Q':

128 

129                     printf("%d\n",tree.sum(b)-tree.sum(a-1));

130 

131                     break;

132 

133                 case 'A':

134 

135                     tree.add(a,b);

136 

137                     break;

138 

139                 case 'S':

140 

141                     tree.add(a,-b);

142 

143                     break;

144 

145                 }

146 

147             }

148 

149         }

150 

151     }

152 

153     return 0;

154 

155 }
View Code

 


 

再看一道修改区间,然后单点查询的

看hdu 1556

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1556

N个气球排成一排,从左到右依次编号为1,2,3....N.每次给定2个整数a b(a <= b),lele便为骑上他的“小飞鸽"牌电动车从气球a开始到气球b依次给每个气球涂一次颜色。但是N次以后lele已经忘记了第I个气球已经涂过几次颜色了,你能帮他算出每个气球被涂过几次颜色吗?

 

这题是修改区间的,单点查询的,则要注意一点 先对左区间进行操作add(a,1),然后对右边区间进行操作add(b+1,-1),把不该修改的那部分值再修改回来,即实现了对一个区间的值的修改。然后通过sum(i),即可求点(如果有人问为什么是sum(i)而不是bit[i]呢?我只能说你太天真了。。。。自己再纸上画画就能知道。。。。)

 

  1 #include<iostream>

  2 

  3 #include<algorithm>

  4 

  5 #include<cstdio>

  6 

  7 #include<cstring>

  8 

  9 using namespace std;

 10 

 11 const int MAX=100001;

 12 

 13 int N;

 14 

 15 class BIT2

 16 

 17 {

 18 

 19 private:

 20 

 21     int bit[MAX];

 22 

 23     int lowbit(int t)

 24 

 25     {

 26 

 27         return t&-t;

 28 

 29     }

 30 

 31 public:

 32 

 33     BIT2()

 34 

 35     {

 36 

 37         memset(bit,0,sizeof(bit));

 38 

 39     }

 40 

 41     int add(int i,int v)

 42 

 43     {

 44 

 45         while(i<=N)

 46 

 47         {

 48 

 49             bit[i]+=v;

 50 

 51             i+=lowbit(i);

 52 

 53         }

 54 

 55     }

 56 

 57     int sum(int i)

 58 

 59     {

 60 

 61         int s=0;

 62 

 63         while(i>0)

 64 

 65         {

 66 

 67             s+=bit[i];

 68 

 69             i-=lowbit(i);

 70 

 71         }

 72 

 73         return s;

 74 

 75     }

 76 

 77 };

 78 

 79 int main()

 80 

 81 {

 82 

 83     while(cin>>N&&N)

 84 

 85     {

 86 

 87         int a,b;

 88 

 89         BIT2 tree;

 90 

 91         for(int i=1;i<=N;i++)

 92 

 93         {

 94 

 95             scanf("%d%d",&a,&b);

 96 

 97             tree.add(a,1);

 98 

 99             tree.add(b+1,-1);

100 

101         }

102 

103         for(int i=1;i<=N;i++)

104 

105         {

106 

107             if(i!=1) cout<<" ";

108 

109             printf("%d",tree.sum(i));

110 

111         }

112 

113         cout<<endl;

114 

115     }

116 

117     return 0;

118 

119 }

120 

121  
View Code

 

 

再看一道二维的

Hdu1892

http://acm.hdu.edu.cn/showproblem.php?pid=1892

 

跟一维主要的区别

 1 void init()

 2 

 3 {

 4 

 5     for(int i=1;i<MAX;i++)

 6 

 7         for(int j=1;j<MAX;j++)

 8 

 9         {

10 

11             d[i][j]=1;

12 

13             c[i][j]=lowbit(i)*lowbit(j);

14 

15         }

16 

17 }

18 

19 int sum(int i,int j)

20 

21 {

22 

23     int tot=0;

24 

25     for(int x=i;x>0;x-=lowbit(x))

26 

27         for(int y=j;y>0;y-=lowbit(y))

28 

29         {

30 

31             tot+=c[x][y];

32 

33         }

34 

35     return tot;

36 

37 }

38 

39 void add(int i,int j,int v)

40 

41 {

42 

43     for(int x=i;x<MAX;x+=lowbit(x))

44 

45         for(int y=j;y<MAX;y+=lowbit(y))

46 

47         {

48 

49             c[x][y]+=v;

50 

51         }

52 

53 }
View Code

 

 

你可能感兴趣的:(树状数组)