这一篇我们来看树状数组的加强版线段树,树状数组能玩的线段树一样可以玩,而且能玩的更好,他们在区间求和,最大,平均
等经典的RMQ问题上有着对数时间的优越表现。
一:线段树
线段树又称"区间树”,在每个节点上保存一个区间,当然区间的划分采用折半的思想,叶子节点只保存一个值,也叫单元节点,所
以最终的构造就是一个平衡的二叉树,拥有CURD的O(lgN)的时间。
从图中我们可以清楚的看到[0-10]被划分成线段的在树中的分布情况,针对区间[0-N],最多有2N个节点,由于是平衡二叉树的形
式也可以像堆那样用数组来玩,不过更加耗费空间,为最多4N个节点,在针对RMQ的问题上,我们常常在每个节点上增加一些sum,
max,min等变量来记录求得的累加值,当然你可以理解成动态规划的思想,由于拥有logN的时间,所以在RMQ问题上比数组更加优美。
二:代码
1:在节点中定义一些附加值,方便我们处理RMQ问题。
1 #region 线段树的节点 2 /// <summary> 3 /// 线段树的节点 4 /// </summary> 5 public class Node 6 { 7 /// <summary> 8 /// 区间左端点 9 /// </summary> 10 public int left; 11 12 /// <summary> 13 /// 区间右端点 14 /// </summary> 15 public int right; 16 17 /// <summary> 18 /// 左孩子 19 /// </summary> 20 public Node leftchild; 21 22 /// <summary> 23 /// 右孩子 24 /// </summary> 25 public Node rightchild; 26 27 /// <summary> 28 /// 节点的sum值 29 /// </summary> 30 public int Sum; 31 32 /// <summary> 33 /// 节点的Min值 34 /// </summary> 35 public int Min; 36 37 /// <summary> 38 /// 节点的Max值 39 /// </summary> 40 public int Max; 41 } 42 #endregion
2:构建(Build)
前面我也说了,构建有两种方法,数组的形式或者链的形式,各有特点,我就采用后者,时间为O(N)。
1 #region 根据数组构建“线段树" 2 /// <summary> 3 /// 根据数组构建“线段树" 4 /// </summary> 5 /// <param name="length"></param> 6 public Node Build(int[] nums) 7 { 8 this.nums = nums; 9 10 return Build(nodeTree, 0, nums.Length - 1); 11 } 12 #endregion 13 14 #region 根据数组构建“线段树" 15 /// <summary> 16 /// 根据数组构建“线段树" 17 /// </summary> 18 /// <param name="left"></param> 19 /// <param name="right"></param> 20 public Node Build(Node node, int left, int right) 21 { 22 //说明已经到根了,当前当前节点的max,sum,min值(回溯时统计上一层节点区间的值) 23 if (left == right) 24 { 25 return new Node 26 { 27 left = left, 28 right = right, 29 Max = nums[left], 30 Min = nums[left], 31 Sum = nums[left] 32 }; 33 } 34 35 if (node == null) 36 node = new Node(); 37 38 node.left = left; 39 40 node.right = right; 41 42 node.leftchild = Build(node.leftchild, left, (left + right) / 2); 43 44 node.rightchild = Build(node.rightchild, (left + right) / 2 + 1, right); 45 46 //统计左右子树的值(min,max,sum) 47 node.Min = Math.Min(node.leftchild.Min, node.rightchild.Min); 48 node.Max = Math.Max(node.leftchild.Max, node.rightchild.Max); 49 node.Sum = node.leftchild.Sum + node.rightchild.Sum; 50 51 return node; 52 } 53 #endregion
3:区间查询
在线段树中,区间查询还是有点小麻烦的,存在三种情况。
① 完全包含:也就是节点的线段范围完全在查询区间的范围内,这说明我们要么到了“单元节点",要么到了一个子区间,这种情况
就是我找到了查询区间的某一个子区间,直接累积该区间值就可以了。
② 左交集: 这种情况我们需要到左子树去遍历。
③右交集: 这种情况我们需要到右子树去遍历。
比如说:我要查询Sum[4-8]的值,最终会成为:Sum总=Sum[4-4]+Sum[5-5]+Sum[6-8],时间为log(N)。
1 #region 区间查询 2 /// <summary> 3 /// 区间查询(分解) 4 /// </summary> 5 /// <returns></returns> 6 public int Query(int left, int right) 7 { 8 int sum = 0; 9 10 Query(nodeTree, left, right, ref sum); 11 12 return sum; 13 } 14 15 /// <summary> 16 /// 区间查询 17 /// </summary> 18 /// <param name="left">查询左边界</param> 19 /// <param name="right">查询右边界</param> 20 /// <param name="node">查询的节点</param> 21 /// <returns></returns> 22 public void Query(Node node, int left, int right, ref int sum) 23 { 24 //说明当前节点完全包含在查询范围内,两点:要么是单元节点,要么是子区间 25 if (left <= node.left && right >= node.right) 26 { 27 //获取当前节点的sum值 28 sum += node.Sum; 29 return; 30 } 31 else 32 { 33 //如果当前的left和right 和node的left和right无交集,此时可返回 34 if (node.left > right || node.right < left) 35 return; 36 37 //找到中间线 38 var middle = (node.left + node.right) / 2; 39 40 //左孩子有交集 41 if (left <= middle) 42 { 43 Query(node.leftchild, left, right, ref sum); 44 } 45 //右孩子有交集 46 if (right >= middle) 47 { 48 Query(node.rightchild, left, right, ref sum); 49 } 50 51 } 52 } 53 #endregion
4:更新操作
这个操作跟树状数组中的更新操作一样,当递归的找到待修改的节点后,改完其值然后在当前节点一路回溯,并且在回溯的过程中一
路修改父节点的附加值直到根节点,至此我们的操作就完成了,复杂度同样为logN。
1 #region 更新操作 2 /// <summary> 3 /// 更新操作 4 /// </summary> 5 /// <param name="index"></param> 6 /// <param name="key"></param> 7 public void Update(int index, int key) 8 { 9 Update(nodeTree, index, key); 10 } 11 12 /// <summary> 13 /// 更新操作 14 /// </summary> 15 /// <param name="index"></param> 16 /// <param name="key"></param> 17 public void Update(Node node, int index, int key) 18 { 19 if (node == null) 20 return; 21 22 //取中间值 23 var middle = (node.left + node.right) / 2; 24 25 //遍历左子树 26 if (index >= node.left && index <= middle) 27 Update(node.leftchild, index, key); 28 29 //遍历右子树 30 if (index <= node.right && index >= middle + 1) 31 Update(node.rightchild, index, key); 32 33 //在回溯的路上一路更改,复杂度为lgN 34 if (index >= node.left && index <= node.right) 35 { 36 //说明找到了节点 37 if (node.left == node.right) 38 { 39 nums[index] = key; 40 41 node.Sum = node.Max = node.Min = key; 42 } 43 else 44 { 45 //回溯时统计左右子树的值(min,max,sum) 46 node.Min = Math.Min(node.leftchild.Min, node.rightchild.Min); 47 node.Max = Math.Max(node.leftchild.Max, node.rightchild.Max); 48 node.Sum = node.leftchild.Sum + node.rightchild.Sum; 49 } 50 } 51 } 52 #endregion
最后我们做个例子,在2000000的数组空间中,寻找200-3000区间段的sum值,看看他的表现如何。
1 using System; 2 using System.Collections.Generic; 3 using System.Linq; 4 using System.Text; 5 using System.Diagnostics; 6 using System.Threading; 7 using System.IO; 8 9 namespace ConsoleApplication2 10 { 11 public class Program 12 { 13 public static void Main() 14 { 15 int[] nums = new int[200 * 10000]; 16 17 for (int i = 0; i < 10000 * 200; i++) 18 { 19 nums[i] = i; 20 } 21 22 Tree tree = new Tree(); 23 24 //将当前数组构建成 “线段树” 25 tree.Build(nums); 26 27 var watch = Stopwatch.StartNew(); 28 29 var sum = tree.Query(200, 3000); 30 31 watch.Stop(); 32 33 Console.WriteLine("耗费时间:{0}ms, 当前数组有:{1}个数字, 求出Sum=:{2}", watch.ElapsedMilliseconds, nums.Length, sum); 34 35 Console.Read(); 36 } 37 } 38 39 public class Tree 40 { 41 #region 线段树的节点 42 /// <summary> 43 /// 线段树的节点 44 /// </summary> 45 public class Node 46 { 47 /// <summary> 48 /// 区间左端点 49 /// </summary> 50 public int left; 51 52 /// <summary> 53 /// 区间右端点 54 /// </summary> 55 public int right; 56 57 /// <summary> 58 /// 左孩子 59 /// </summary> 60 public Node leftchild; 61 62 /// <summary> 63 /// 右孩子 64 /// </summary> 65 public Node rightchild; 66 67 /// <summary> 68 /// 节点的sum值 69 /// </summary> 70 public int Sum; 71 72 /// <summary> 73 /// 节点的Min值 74 /// </summary> 75 public int Min; 76 77 /// <summary> 78 /// 节点的Max值 79 /// </summary> 80 public int Max; 81 } 82 #endregion 83 84 Node nodeTree = new Node(); 85 86 int[] nums; 87 88 #region 根据数组构建“线段树" 89 /// <summary> 90 /// 根据数组构建“线段树" 91 /// </summary> 92 /// <param name="length"></param> 93 public Node Build(int[] nums) 94 { 95 this.nums = nums; 96 97 return Build(nodeTree, 0, nums.Length - 1); 98 } 99 #endregion 100 101 #region 根据数组构建“线段树" 102 /// <summary> 103 /// 根据数组构建“线段树" 104 /// </summary> 105 /// <param name="left"></param> 106 /// <param name="right"></param> 107 public Node Build(Node node, int left, int right) 108 { 109 //说明已经到根了,当前当前节点的max,sum,min值(回溯时统计上一层节点区间的值) 110 if (left == right) 111 { 112 return new Node 113 { 114 left = left, 115 right = right, 116 Max = nums[left], 117 Min = nums[left], 118 Sum = nums[left] 119 }; 120 } 121 122 if (node == null) 123 node = new Node(); 124 125 node.left = left; 126 127 node.right = right; 128 129 node.leftchild = Build(node.leftchild, left, (left + right) / 2); 130 131 node.rightchild = Build(node.rightchild, (left + right) / 2 + 1, right); 132 133 //统计左右子树的值(min,max,sum) 134 node.Min = Math.Min(node.leftchild.Min, node.rightchild.Min); 135 node.Max = Math.Max(node.leftchild.Max, node.rightchild.Max); 136 node.Sum = node.leftchild.Sum + node.rightchild.Sum; 137 138 return node; 139 } 140 #endregion 141 142 #region 区间查询 143 /// <summary> 144 /// 区间查询(分解) 145 /// </summary> 146 /// <returns></returns> 147 public int Query(int left, int right) 148 { 149 int sum = 0; 150 151 Query(nodeTree, left, right, ref sum); 152 153 return sum; 154 } 155 156 /// <summary> 157 /// 区间查询 158 /// </summary> 159 /// <param name="left">查询左边界</param> 160 /// <param name="right">查询右边界</param> 161 /// <param name="node">查询的节点</param> 162 /// <returns></returns> 163 public void Query(Node node, int left, int right, ref int sum) 164 { 165 //说明当前节点完全包含在查询范围内,两点:要么是单元节点,要么是子区间 166 if (left <= node.left && right >= node.right) 167 { 168 //获取当前节点的sum值 169 sum += node.Sum; 170 return; 171 } 172 else 173 { 174 //如果当前的left和right 和node的left和right无交集,此时可返回 175 if (node.left > right || node.right < left) 176 return; 177 178 //找到中间线 179 var middle = (node.left + node.right) / 2; 180 181 //左孩子有交集 182 if (left <= middle) 183 { 184 Query(node.leftchild, left, right, ref sum); 185 } 186 //右孩子有交集 187 if (right >= middle) 188 { 189 Query(node.rightchild, left, right, ref sum); 190 } 191 192 } 193 } 194 #endregion 195 196 #region 更新操作 197 /// <summary> 198 /// 更新操作 199 /// </summary> 200 /// <param name="index"></param> 201 /// <param name="key"></param> 202 public void Update(int index, int key) 203 { 204 Update(nodeTree, index, key); 205 } 206 207 /// <summary> 208 /// 更新操作 209 /// </summary> 210 /// <param name="index"></param> 211 /// <param name="key"></param> 212 public void Update(Node node, int index, int key) 213 { 214 if (node == null) 215 return; 216 217 //取中间值 218 var middle = (node.left + node.right) / 2; 219 220 //遍历左子树 221 if (index >= node.left && index <= middle) 222 Update(node.leftchild, index, key); 223 224 //遍历右子树 225 if (index <= node.right && index >= middle + 1) 226 Update(node.rightchild, index, key); 227 228 //在回溯的路上一路更改,复杂度为lgN 229 if (index >= node.left && index <= node.right) 230 { 231 //说明找到了节点 232 if (node.left == node.right) 233 { 234 nums[index] = key; 235 236 node.Sum = node.Max = node.Min = key; 237 } 238 else 239 { 240 //回溯时统计左右子树的值(min,max,sum) 241 node.Min = Math.Min(node.leftchild.Min, node.rightchild.Min); 242 node.Max = Math.Max(node.leftchild.Max, node.rightchild.Max); 243 node.Sum = node.leftchild.Sum + node.rightchild.Sum; 244 } 245 } 246 } 247 #endregion 248 } 249 }