线段树 | Segment Tree

在计算机科学中,Segment Tree也称为统计树,是一种树数据结构,用于存储有关区间或段的信息。它允许查询哪些存储的段包含给定点。原则上,它是一个静态结构;也就是说,它是一种一旦建成就无法修改的结构,一个类似的数据结构是区间树。

 

一个集合 I 的 n 个间隔的线段树使用 O(n log n)的空间 存储,并且可以在 O(n log n) 时间内构建。线段树支持在O(log n + k) 的时间内搜索包含某个点的所有区间,k 是检索到的区间或段的数量。

要理解线段树,我们先来考虑下面的问题:

我们有一个数组 arr[0 . . . n-1]。我们应该能够  

  • 求从索引 l 到 r 的元素之和,其中 0 <= l <= r <= n-1
  • 将数组的指定元素的值更改为新值x。我们需要做 arr[i] = x 其中 0 <= i <= n-1。

最简单的解决上述问题的方法是常规遍历,但其更新和求和的时间复杂度都是O(n),这显然不是最优的。另一种方法是Prefix Sum的方法,可以在O(1)的时间内得到一个区间的和,但更新操作的时间仍然是O(n)。

想要让更新和求和两个操作的时间复杂度都保持在O(logN),Segment Tree是一个很好的方法。

Segment Tree的表示

  • 叶节点是输入数组的元素。 
  • 每个内部节点代表叶节点的一些合并。对于不同的问题,合并可能会有所不同。大多数情况下,合并是一个节点下的叶子节点的总和。
  • 树的数组表示用于表示段树。对于索引 i 处的每个节点,左子节点位于索引(2*i+1)处,右子节点位于(2*i+2)处,父节点位于  (⌊(i – 1) / 2⌋) 处(和普通的二叉树一样)。

 简单的Segment Tree

线段树 | Segment Tree_第1张图片

从给定数组中构建一个线段树 

我们从段arr[0 开始。. . n-1]。并且每次我们将当前段一分为二(如果它还没有变成长度为 1 的段),然后在分割后的两段上执行相同的操作,对于每个这样的段,我们将其对应的元素总和存储在相应的节点中。 
所构建的段树的所有级别都将被完全填充,除了最后一层。此外,这棵树将是一个完整的二叉树,因为我们总是在每一层将段分成两部分。由于构建的树始终是具有 n 个叶子的完整二​​叉树,因此将有n-1 个内部节点。所以节点的总数将是2*n – 1。

注意:线段树的结果本质上是一个完整的二叉树,二叉树的节点上存储着我们想要的值,因此线段树的高度就是log₂N,由于树是使用数组表示的,并且必须维护父索引和子索引之间的关系,因此为段树分配的内存大小将为(2 * 2 ⌈log 2 n⌉  – 1)。

查询给定区间范围的和

下面是获取元素总和的算法

线段树 | Segment Tree_第2张图片

在上面的实现中,我们需要考虑三种情况

  • 如果遍历树时当前节点的范围不在给定范围内,则不会在 ans 中添加该节点的值
  • 如果节点范围与给定范围部分重叠,则根据重叠向左或向右移动
  • 如果范围与给定范围完全重叠,则将其添加到 ans

更新操作

线段树构造和查询操作一样,更新也可以递归完成。我们得到了一个需要更新的索引。设diff为要添加的值。我们从线段树的根开始,将diff添加到在其范围内具有给定索引的所有节点。如果一个节点的范围不包含给定的索引,我们不会对该节点进行任何更改。 

代码演示(C++)


#include 
using namespace std;

// A utility function to get the middle index from corner indexes.
int getMid(int s, int e) { return s + (e -s)/2; }



int getSumUtil(int *st, int ss, int se, int qs, int qe, int si)
{
	// If segment of this node is a part of given range, then return
	// the sum of the segment
	if (qs <= ss && qe >= se)
		return st[si];

	// If segment of this node is outside the given range
	if (se < qs || ss > qe)
		return 0;

	// If a part of this segment overlaps with the given range
	int mid = getMid(ss, se);
	return getSumUtil(st, ss, mid, qs, qe, 2*si+1) +
		getSumUtil(st, mid+1, se, qs, qe, 2*si+2);
}

/* A recursive function to update the nodes which have the given */
void updateValueUtil(int *st, int ss, int se, int i, int diff, int si)
{
	// Base Case: If the input index lies outside the range of
	// this segment
	if (i < ss || i > se)
		return;

	// If the input index is in range of this node, then update
	// the value of the node and its children
	st[si] = st[si] + diff;
	if (se != ss)
	{
		int mid = getMid(ss, se);
		updateValueUtil(st, ss, mid, i, diff, 2*si + 1);
		updateValueUtil(st, mid+1, se, i, diff, 2*si + 2);
	}
}

// The function to update a value in input array and segment tree.
// It uses updateValueUtil() to update the value in segment tree
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
	// Check for erroneous input index
	if (i < 0 || i > n-1)
	{
		cout<<"Invalid Input";
		return;
	}

	// Get the difference between new value and old value
	int diff = new_val - arr[i];

	// Update the value in array
	arr[i] = new_val;

	// Update the values of nodes in segment tree
	updateValueUtil(st, 0, n-1, i, diff, 0);
}

// Return sum of elements in range from index qs (query start)
// to qe (query end). It mainly uses getSumUtil()
int getSum(int *st, int n, int qs, int qe)
{
	// Check for erroneous input values
	if (qs < 0 || qe > n-1 || qs > qe)
	{
		cout<<"Invalid Input";
		return -1;
	}

	return getSumUtil(st, 0, n-1, qs, qe, 0);
}

// A recursive function that constructs Segment Tree for array[ss..se].
// si is index of current node in segment tree st
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
	// If there is one element in array, store it in current node of
	// segment tree and return
	if (ss == se)
	{
		st[si] = arr[ss];
		return arr[ss];
	}

	// If there are more than one elements, then recur for left and
	// right subtrees and store the sum of values in this node
	int mid = getMid(ss, se);
	st[si] = constructSTUtil(arr, ss, mid, st, si*2+1) +
			constructSTUtil(arr, mid+1, se, st, si*2+2);
	return st[si];
}

/* Function to construct segment tree from given array. This function
allocates memory for segment tree and calls constructSTUtil() to
fill the allocated memory */
int *constructST(int arr[], int n)
{
	// Allocate memory for the segment tree

	//Height of segment tree
	int x = (int)(ceil(log2(n)));

	//Maximum size of segment tree
	int max_size = 2*(int)pow(2, x) - 1;

	// Allocate memory
	int *st = new int[max_size];

	// Fill the allocated memory st
	constructSTUtil(arr, 0, n-1, st, 0);

	// Return the constructed segment tree
	return st;
}

// Driver program to test above functions
int main()
{
	int arr[] = {1, 3, 5, 7, 9, 11};
	int n = sizeof(arr)/sizeof(arr[0]);

	// Build segment tree from given array
	int *st = constructST(arr, n);

	// Print sum of values in array from index 1 to 3
	cout<<"Sum of values in given range = "<除了上面所说的基本操作外,Segment Tree还可以做范围最小查询,相关的具体算法,我们会在以后的文章中讨论。

 其他语言实现下载链接:

(包含各种语言:C语言、Python、Java、C++等均有示例)

免费​资源下载:Segment Tree

你可能感兴趣的:(java,数据结构,算法,leetcode)