定义: lowbit(x)=x&(-x)。
那么这个式子是什么意思呢?先来看-x从二进制的角度发生了什么。我们在计组中了解过,整数在计算机一般通过补码储存,并且一个补码表示的整数x变成其相反数-x的过程相当于把x的二进制的每一位都取反,然后末尾加1。而这等价于直接把x的二进制最右边的1的左边每一位都取反。例子如下:
x | 0000001101001 100 |
---|---|
-x | 1111110010110 100 |
x&(-x) | 0000000000000 000 |
对x=(0000001101001100)2来说,最右边的1是在2号位,因此把它左边的所有位全部取反。通过-x就容易推导出lowbit(x)=x&(-x)就是取x的二进制最右边的1和它右边所有0,因此它一定是2的幂次,如1、2、4、8等。例如对x=6=(110)2来说,x&(-x)=(010)2。即lowbit(x)也可以理解为能整除x的最大2的幂次。 |
先来看一个问题:给出一个整数序列A,元素个数为N(N≤105),接下来查询K(K≤105)次,每次查询将给出一个正整数x(x≤N),求前x个整数和。
对于这个问题一般做法就是开一个sum数组,其中sum[i]表示前i个整数之和(数组下标从1开始),这样sum数组就可以在输入N个整数时就预处理出来。接着每次查询前x个整数之和,输出时,输出sum[x]即可。
现在升级一下问题,假设在查询的过程中可能随时给第x个整数加上一个整数v,要求在后才想你中能实时输出前x个整数之和。
对于这个问题,如果还是之前的做法,虽然单次查询的时间复杂度仍为O(1),但在进行更新时却需要给sum[x],sum[x+1]……sum[N]都加上整数v,这使得单次更新的时间复杂度为O(N),那么如果K次操作中大部分都是更新操作,操作的总复杂度就会使O(KN),显然无法承受。那要怎么办呢?
当当当!这边引入了我们本文的核心树状数组(BIT)。它其实仍然是一个数组,并且与sum数组类似,是一个用来记录和的数组,只不过它存放的不是前i个整数之和,而是在i号位之前(含i号位)lowbit(i)之和。
数组A使原始数组,有A[1]~A[16]共16个元素;数组C是树状数组,其中C[i]存放数组A中i号位之前lowbit(i)个元素之和(到这里各位可以淡化二进制的概念,不必过分关心)
C[i]的覆盖长度是lowbit(i),它是2的幂次,即1、2、4、8等。
接下来思考一下,在这样的定义下,怎样解决下面两个问题:
先来看第一个问题,如何设计函数getSum(x),返回前x个数之和。
假设想要查询A[1]+…+A[14],那么从树状数组的定义出发,它实际是什么东西呢?我们很容易发现A[1]+…+A[14]=C[8]+C[12]+C[14],又比如要查询A[1]+…A[11],从图中同样可以得到A[1]+…A[11]=C[8]+C[10]+C[11]。那么怎样知道A[1]+…+A[x]对应的是树状数组中的哪些项?可通过如下方法:
记SUM(1,x)=A[1]+…+A[x],由于C[x]的覆盖长度为lowbit(x),因此C[x]=A[x-lowbit(x)+1]+…A[x]
于是马上可以得到
SUM(1,x)=A[1]+…+A[x]
=A[1]+…A[x-lowbit(x)]+A[x-lowbit(x)+1]+…+A[x]
=SUM(1,x-lowbit(x))+C[x]
这样就把SUM(1,x)转换为SUM(1,x-lowbit(x))了
下面给出getsum函数:
//getSum函数返回前x个整数之和
int getSum(int x){
int sum=0;
for(int i=x;i>0;i-=lowbit(i)){//注意是i>0而不是i>=0
sum+=c[i];//累计c[i],然后把问题缩小为SUM(1,i-lowbit(i))
}
return sum;//返回和
}
结合上面几个图就会发现,getSum函数的过程实际上是在沿着一条不断左上的路径行进。另外如果要求数组下标在区间[x,y]内的数之和,即A[x]+A[x+1]+……+A[y],可以转换成getSum(y)-getSum(x-1)来解决。
下面来设计第二个问题,如何设计update(x,v),实现将第x个数加上一个数v的功能。
要让A[x]加上v,就是要寻找树状数组c中能覆盖A[x]的那些元素,让它们都加上v,只要总是寻找离当前的“矩形”C[x]最近的“矩形”C[y],使得C[y]能狗覆盖C[x]即可。
那么,如何找到呢?问题等价于求一个尽可能小的整数a,使得lowbit(y)必须大于lowbit(x)。显然,由于lowbit(x)是取x的二进制最右边的1的位置,因此如果lowbit(a)
于是update函数的做法就很明确了,只要让x不断加上lowbit(x),并让每步的C[x]都加上v,直到x超过给定的数据范围为止
//update函数将第x个整数加上v
void update(int x,int y){
for(int i=x;i<=N;i+=lowbit(i)){//注意i必须能取到N
c[i]+=v;//让c[i]加上v,然后让c[i+lowbit(i)]加上v
}
}
这便是树状数组最核心的两个代码思想了,下面看一个经典问题
给定一个有N个正整数的序列A(<=105,A[i]<=105),对序列中的每个数,求出序列中它左边比它小的数的个数。
#include
using namespace std;
const int maxn=100010;
#define lowbit(i) ((i)&(-i)) //lowbit写成宏定义的形式
int c[maxn]; //树状数组
//update函数将第x个整数加上v
void update(int x,int v){
for(int i=x;i0;i-=lowbit(i)){ //注意是i>0而不是i>=0
sum+=c[i];//累计c[i],然后把问题缩小为SUM(1,i-lowbit(i))
}
return sum;//返回和
}
int main(){
int n,x;
cin>>n;
memset(c,0,sizeof(c));//树状数组初值为0
for(int i=0;i>x;
update(x,i);//x的出现次数加1
}
return 0;
}