维度探索——二维线段树

前言

线段树是一个神奇的东西,可以O(n)建树,O(logn)修改、查询,维护一个区间的性质。但是线段树维护的序列一定是一维的,如果我要维护一个“二维”的结构呢?就比如说,维护一个矩阵中子矩阵的和。简单地说就是给你一个表格,每次用“圈出”一个矩形的部分让你求它所有元素的和。

维度探索——二维线段树_第1张图片

没有学过线段树的同学们一定要先学一下线段树一定要先学习一下,再来看这篇博客。

学习链接: 我与线段树的故事(纯新手请进)

1.静态二维子矩阵和

解决这个问题自然要从静态(一维)区间和得出灵感。静态一维区间和我们用的方法是求“前缀和”,我们可以用O(n)的时间复杂度求出一个pre数组,pre[i]表示闭区间[1,i]对应元素的和。如果我想要求区间[i,j]的区间和,用计算pre[j]-pre[i-1]就可以O(1)解决问题。

同学们可以尝试着把这个理论推广到二维,我们可以去维护一个二维“前缀和”来解决这个问题。

维度探索——二维线段树_第2张图片

区域“I”的矩阵和,就相当于是区域“I+II+III+IV”的和减去区域“III+IV”、减去区域“III+II”、再加上区域“III”。

如果我们用sum(i,j,k,l)表示区域“I”,那么就有sum(i,j,k,l)=sum(1,j,1,l)-sum(1,i,1,l)-sum(1,j,1,k)+sum(1,i,1,k)。这样我们就把所有的数据表示成了一个“二维前缀”的形式了。我们可以用pre(i,j)表示sum(1,i,1,j),就有sum(i,j,k,l)=pre(j,l)-pre(i,l)-pre(j,k)+pre(i,k)。

pre(i,j)如何求解呢?显然可以使用和sum同样的方法:pre[i][j]=pre[i-1][j]+pre[i][j-1]-pre[i-1][j-1]+a[i][j](a表示原数组)。

请看伪代码:

Init:
For i = 1 to n
    For j = 1 to m
        pre[i][j]=pre[i-1][j]+pre[i][j-1]-pre[i-1][j-1]+a[i][j]

Query:
sum(i,j,k,l)=pre[j][l]-pre[i][l]-pre[j][k]+pre[i][k]

2.二维线段树

二维线段树,每一个节点对应一个子矩阵(根节点代表整体),每个节点有四个儿子节点,分别表示它的“左上,左下,右上,右下”四个部分,例如下图:

维度探索——二维线段树_第3张图片

当然,如果边长不是二的整数次幂也是可以这样二分的:

维度探索——二维线段树_第4张图片

这样我们就得到了一个建树的方法:

build(Root,Left,Right,Up,Down)
    if 确定到唯一元素
        Root.sum=这个元素
    else
    if 这个区域只有一列
        Root.左上子=新结点
        Root.左下子=新结点

        int mid=(Up+Down)/2
        build(Root.左上子,Left,Right,Up,mid)
        build(Root.左下子,Left,Right,mid+1,Down)

        Root.sum=Root.左上子.sum+Root.左下子.sum
    else
    if 这个区间只有一行
        Root.左上子=新结点
        Root.右上子=新结点

        int mid=(Left+Right)/2
        build(Root.左上子,Left,mid,Up,Down)
        build(Root.右上子,mid+1,Right,Up,Down)

        Root.sum=Root.左上子.sum+Root.右上子.sum
    else
        Root.左上子=新结点
        Root.左下子=新结点
        Root.右上子=新结点
        Root.右下子=新结点

        int midLR=(Left+Right)/2
        int midUD=(Up+Down)/2
        build(Root.左上子,Left,midLR,Up,midUD)
        build(Root.左下子,Left,midLR,midUD+1,Down)
        build(Root.右上子,midLR+1,Right,Up,midUD)
        build(Root.右下子,midLR+1,Right,MidUD+1,Down)

        Root.sum=Root.左上子.sum+Root.左下子.sum+Root.右上子.sum+Root.右下子.sum

查询还是比较简单的(因为没有lazy,不懂的回去复习线段树!):

Query(Root,Left,Right,Up,Down)
    if Root == NULL//这样当查询NULL结点时可以直接忽略掉,不会RE
        return 0
    return Query(Root.左上子,Left,Right,Up,Down)+
        Query(Root.左下子,Left,Right,Up,Down)+
        Query(Root.右上子,Left,Right,Up,Down)+
        Query(Root.右下子,Left,Right,Up,Down)

然后是正经的代码:

struct NODE
{
    int l,r,u,d;
    int luch,ruch,ldch,rdch;
    int sum;
    NODE(int L=0,int R=0,int U=0,int D=0,
        int LUCH=0,int RUCH=0,int LDCH=0,int RDCH=0,
        int SUM=0){
        l=L;r=R;u=U;d=D;
        luch=LUCH;ruch=RUCH;ldch=LDCH;rdch=RDCH;
        sum=SUM;
    }
}ns[1048576];

int newnode=1;//当前亟待申请的节点

#define LST(ROOT) (ns[ROOT].l)
#define RST(ROOT) (ns[ROOT].r)
#define UST(ROOT) (ns[ROOT].u)
#define DST(ROOT) (ns[ROOT].d)//表示一个结点的区间范围

#define LUCH(ROOT) (ns[ROOT].luch)
#define RUCH(ROOT) (ns[ROOT].ruch)
#define LDCH(ROOT) (ns[ROOT].ldch)
#define RDCH(ROOT) (ns[ROOT].rdch)//表示一个结点的四个儿子

#define SUM(ROOT) (ns[ROOT].sum)//这些define可以使代码更好理解,但实际上没什么必要

int a[101][101];

void build(int root,int l,int r,int u,int d)
{
    if(l==r && u==d)
        ns[root]=NODE(l,r,u,d,-1,-1,-1,-1,a[u][l]);
    else
    if(u==d)
    {
        int nlu=newnode++;
        int nru=newnode++;
        int mid=(l+r)/2;
        build(nlu,l,mid,u,d);
        build(nru,mid+1,r,u,d);
        ns[root]=NODE(l,r,u,d,nlu,nru,-1,-1,SUM(nlu)+SUM(nru));
    }else
    if(l==r)
    {
        int nlu=newnode++;
        int nld=newnode++;
        int mid=(u+d)/2;
        build(nlu,l,r,u,mid);
        build(nld,l,r,mid+1,d);
        ns[root]=NODE(l,r,u,d,nlu,-1,nld,-1,SUM(nlu)+SUM(nld));
    }else{
        int nlu=newnode++;
        int nru=newnode++;
        int nld=newnode++;
        int nrd=newnode++;
        int midlr=(l+r)/2;
        int midud=(u+d)/2;
        build(nlu,l,midlr,u,midud);
        build(nru,midlr+1,r,u,midud);
        build(nld,l,midlr,midud+1,d);
        build(nrd,midlr+1,r,midud+1,d);
        ns[root]=NODE(l,r,u,d,nlu,nru,nld,nrd,SUM(nlu)+SUM(nru)+SUM(nld)+SUM(nrd));
    }
}

int ask(int root,int l,int r,int u,int d)
{
    if(root==-1)
        return 0;
    if( (l<=LST(root) && RST(root)<=r) &&
        (u<=UST(root) && DST(root)<=d))
        return SUM(root);
    if( (LST(root)>r || RST(root)d || DST(root)return 0;
    int nlu=LUCH(root);
    int nru=RUCH(root);
    int nld=LDCH(root);
    int nrd=RDCH(root);
    return ask(nlu,l,r,u,d)+ask(nru,l,r,u,d)+ask(nld,l,r,u,d)+ask(nrd,l,r,u,d);
}

理论上来讲,代码与普通线段树是极其相似的。

后记

赶稿匆忙,如有谬误,望同学们谅解。

你可能感兴趣的:(数据结构,算法导论)