树状数组原理解析

我们知道,对长度为n的数组,如果我们要改变其中某个值,则时间复杂度为O(1)。如果要求出S[m]=a[1]+a[2]+.....+a[m],则需要O(m)的时间复杂度。若我们一边修改数组的值,一边要求求出其部分和S[m],使用一般的方法,时间复杂度是O(m*n)的,若m和查找次数n很大,那么该算法将不可取。

为了解决这个问题,出现了树状数组这一数据结构。它可以以O(log n)的时间复杂度修改数组中的值,同时以O(log n)的时间复杂度求部分和。这里的log是以2为底的。


1、基本结构

树状数组可以简单地用下面的图理解:

树状数组原理解析_第1张图片

C[1]=A[1]

C[2]=A[1]+A[2]

C[3]=A[3]

C[4]=A[1]+A[2]+A[3]+A[4]

……

可以看到,C[1]“管辖”一个元素,而C[2],C[6]“管辖”2个元素,C[8]“管辖”8个元素。那么C[N]“管辖”多少个元素呢?在树状数组中,这是个重要的概念。我们规定,把N转换为二进制,最右侧0的个数为k,则它管辖的元素个数就是2^k个。如2的二进制位10,2^1=2。4的二进制为100,2^2=4。定义函数lowbit(N)=2^k。那么C[N]管辖了lowbit[N]个元素。根据这样的定义,你就可以自己画出更大范围的树状数组了。


2、树状数组的求和过程及其原理

假设C数组已经初始化好了。如何利用它来求和呢?

举个例子。

我们要求S[6]=A[1]+A[2]+A[3]+A[4]+A[5]+A[6]

先令S=C[6]=A[5]+A[6],注意到lowbit(6)=2,现在的S中只有两个数的和。所以我们还求6-lowbit(6)=4个数的和。这时找到C[4],S=S+C[4],lowbit(4)=4。这时我们已经求了2+4=6个数的和了,所以S就是答案。

可能上面的过程表达的还是不太清楚。我们用一个程序段来更好地表达这个过程:

int sum(int k){
	int s=0;
	for (int i=k;i>=1;i-=lowbit(i))
		s+=c[i];
	return s;
}

在上述程序段中,我们要求S[k],先让i=k;s+=c[i],之后每次让i=i-lowbit(i),s+=c[i],直到到最开头做完为止。

为什么可以这样做?且为什么这样做的复杂度是log级别的?下面再深入解释下原理:

注意观察,其实每次每次让i=i-lowbit(i),都减去了i二进制中最右边的一个1!而i的二进制中最多有log i个1,因此时间复杂度是log级的。

这样做的正确性在于:把求S[k]转换成求几段和的累加,而“分段”是以k的二进制中的1来决定的。

听起来有点拗口,我们还是用一个具体的例子解释。

比如求S[105],105的二进制是1101001。首先S+=C[105],注意到lowbit(105)=1,这样就减去了1101001中最右边的1。1101001-1=1101000。1101000是104的二进制。下面S+=C[104],lowbit(104)=8,104-8=96。96的二进制是1100000。这就减去了右边第二个1。下面S+=C[96],lowbit(96)=32,96-32=64.64的二进制是1000000,这就减去了右边第三个1。最后S+=C[32]结束。

整个求和过程i的变化就是1101001->1101000->1100000->1000000->0.


3、树状数组的初始化

初始化C数组有两种方法。

一种是利用sum函数,C[N]=A[N]+Sum(N-1)-Sum(N-lowbit(N)),这很好理解。

还有一种是利用以下的程序段:

void init(){
	for (int i=1;i<=n;i++){
		c[i]=a[i];
		for(int j=i-1;j>i-lowbit(i);j-=lowbit(j))
			c[i]+=c[j];//注意,这里是加c[j]而不是a[j]
	}
}

这个程序段和上面的思路类似,读者可自行分析其正确性。


4、树状数组对值的修改。

若修改数组中的一个值,则C数组需要做相应修改。我们每次往上找父节点就行了,时间复杂度也是log级的。

参考程序段:

void modify(int x,int v){
	for (int i=x;i<=n;i+=lowbit(i))
		c[i]+=v;
}

5、其他

注意写树状数组的时候,最好把数组下标定为1..N,而不是0..N-1,这样有利于编程方便!


6、参考程序

#include
#include

int n,a[10001],c[10001],m,t,x,y,v;

int lowbit(int x){
	return x & (-x);
}

int sum(int k){
	int s=0;
	for (int i=k;i>=1;i-=lowbit(i))
		s+=c[i];
	return s;
}

void modify(int x,int v){
	for (int i=x;i<=n;i+=lowbit(i))
		c[i]+=v;
}

void init(){
	for (int i=1;i<=n;i++){
		c[i]=a[i];
		for(int j=i-1;j>i-lowbit(i);j-=lowbit(j))
			c[i]+=c[j];//注意,这里是加c[j]而不是a[j]
	}
}

int main(){
	scanf("%d%d",&n,&m);
	a[0]=0;c[0]=0;
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	init();

	for (int i=1;i<=m;i++){
		scanf("%d",&t);
		if (t==1){
			scanf("%d%d",&x,&v);
			modify(x,v);
		}else{
			scanf("%d%d",&x,&y);
			printf("%d\n",sum(y)-sum(x-1));
		}
	}

	return 0;
}


你可能感兴趣的:(算法复习)