树状数组(BIT)是一个查询和修改复杂度都为log(n)的数据结构,主要用于查询任意两位之间的所有元素之和,其编程简单,很容易被实现。而且可以很容易地扩展到二维。让我们来看一道很裸的二维树状数组题:
在一个“打鼹鼠”的游戏中,鼹鼠会不时地从洞中钻出来,不过不会从洞口钻进去(鼹鼠真胆大……)。洞口都在一个大小为n(n<=1024)的正方形中。这个正方形在一个平面直角坐标系中,左下角为(0,0),右上角为(n-1,n-1)。洞口所在的位置都是整点,就是横纵坐标都为整数的点。而SuperBrother也不时地会想知道某一个范围的鼹鼠总数。这就是你的任务。
每个输入文件有多行。
第一行,一个数n,表示鼹鼠的范围。
以后每一行开头都有一个数m,表示不同的操作:
m=1,那么后面跟着3个数x,y,k(0<=x,y<n),表示在点(x,y)处新出现了k只鼹鼠;
m=2,那么后面跟着4个数x1,y1,x2,y2(0<=x1<=x2<n,0<=y1<=y2<n),表示询问矩形(x1,y1)-(x2,y2)内的鼹鼠数量;
m=3,表示老师来了,不能玩了。保证这个数会在输入的最后一行。询问数不会超过10000,鼹鼠数不会超过maxlongint。
把这个问题简单抽象一下。就是:
这个问题,输入数据规模可能很大(虽然实际上并非这样,出题人懒到随机生成数据)。而且是动态修改,显然不能使用前缀和,所以,树状数组就是我们的首选。而然,树状数组本来是一维的,如何把它推广到二维去呢。其实很简单,其方法类似与先生成没一行原数组的一维树状数组,再把一个个一维树状数组组合成二维的其对应关系为:
C[1][1]=a[1][1],C[1][2]=a[1][1]+a[1][2],C[1][3]=a[1][3],C[1][4]=a[1][1]+a[1][2]+a[1][3]+a[1][4],c[1][5]=a[1][5],C[1][6]=a[1][5]+a[1][6],...
C[2][1]=a[1][1]+a[2][1],C[2][2]=a[1][1]+a[1][2]+a[2][1]+a[2][2],C[2][3]=a[1][3]+a[2][3],C[2][4]=a[1][1]+a[1][2]+a[1][3]+a[1][4]+a[2][1]+a[2][2]+a[2][3]+a[2][4], C[2][5]=a[1][5]+a[2][5],C[2][6]=a[1][5]+a[1][6]+a[2][5]+a[2][6],...
C[3][1]=a[3][1],C[3][2]=a[3][1]+a[3][2],C[3][3]=a[3][3],C[3][4]=a[3][1]+a[3][2]+a[3][3]+a[3][4],C[3][5]=a[3][5],C[3][6]=a[3][5]+a[3][6],...
C[4][1]=a[1][1]+a[2][1]+a[3][1]+a[4][1],C[4][2]=a[1][1]+a[1][2]+a[2][1]+a[2][2]+a[3][1]+a[3][2]+a[4][1]+a[4][2],C[4][3]=a[1][3]+a[2][3]+a[3][3]+a[4][3],...(太多了,我就写到3吧)
……
通过观察能发现,第一行是本身,第二行是第一行加上其本身,第三行是本身,第四行是第一、二行加上其本身。这和一维的树状数组是一摸一样的。所以,我们很容易就可以写出修改、查询 函数。
void modfily(int x,int y,int data){ x+=1;y+=1; for (int i=x;i<=n;i+=lowbit(i)) for (int j=y;j<=n;j+=lowbit(j)) c[i][j]+=data; } int sum(int x,int y){ x+=1;y+=1; int result=0; for (int i=x;i>0;i-=lowbit(i)) for (int j=y;j>0;j-=lowbit(j)) result+=c[i][j]; return result; }
这就是二维树状数组的核心了,代码和一维的相仿,异常简单。大家可能有疑问,为什么i,j要加1。其实这是我被坑以后的领悟,如果i,j=0的话,lowbit也永远为0,程序会陷入死循环,直接TLE。从这两个函数我们也可以看出,二维树状数组的查询、修改时间复杂度为log(n)²。
数据结构部分解决了,怎么求一个子矩阵中的数的和呢?画张图我们就可以求出,公式为:sum(x2, y2) - sum(x1-1, y2) - sum(x2, y1-1) + sum(x1-1, y1-1)
顺便附上这道例题的AC code:
#include <cstdio> #include <cstdlib> using namespace std; int m=0,n,x,y,k,x1,y1,c[1026][1026]; void modfily(int,int,int); int sum(int,int); inline int lowbit(int x){ return x&(-x); } int main(void){ scanf("%d",&n); for (int i=0;i<n;++i) for (int j=0;j<n;++j) c[i][j]=0; while (m!=3){ scanf("%d",&m); if (m==1){ scanf("%d%d%d",&x,&y,&k); modfily(x,y,k); } if (m==2){ scanf("%d%d%d%d",&x,&y,&x1,&y1); printf("%d\n",abs(sum(x1,y1)-sum(x-1,y1)-sum(x1,y-1)+sum(x-1,y-1))); } } return 0; } void modfily(int x,int y,int data){ x+=1;y+=1; for (int i=x;i<=n;i+=lowbit(i)) for (int j=y;j<=n;j+=lowbit(j)) c[i][j]+=data; } int sum(int x,int y){ x+=1;y+=1; int result=0; for (int i=x;i>0;i-=lowbit(i)) for (int j=y;j>0;j-=lowbit(j)) result+=c[i][j]; return result; }