百度百科:划分树是一种基于线段树的数据结构。主要用于快速求出(在log(n)的时间复杂度内)序列区间的第k大值。
划分树的基本思想就是对于某个区间,把它划分成两个子区间,左边区间的数小于右边区间的数。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,最后范围缩小到1,就找到了。
建树
划分树建树的时间复杂度和线段树不同,划分树是O(nlogn),划分树的建树依赖一个排好序的数组来辅助建树。
先将读入的数据排序,然后,对于l,r区间,取他们的中间值sorted[mid],然后依次扫l~r区间,将每个节点划分到儿子(即l~mid和mid+1~r)中。注意,这里面分到每个子树的节点是相对有序的,即对于分到每一颗子树里面的数,不改变它们以前的相对顺序。另外,在进行这个过程的时候,记录一个类似前缀和的东西,即l到i这个区间内有多少节点划分到左子树。
画了一颗划分树对数列[1 5 2 3 6 4 7 3 0 0]进行划分,下图有助于理解(红色表示该数被分到左儿子)//摘自小HH's Blog
下面给出建树的具体实现代码并具体解释(这里只是个人的理解),首先给出划分树的数据结构
struct Node { int l,r; int mid() { return (l+r)>>1; } } tree[N<<2]; int sorted[N]; int val[20][N],toLeft[20][N];sorted[]是对输入的数据进行排序后存放的数组,目的是用来辅助建树。
val[][]有人会奇怪这里为什么用二维数组,第一维是用来记录到第几层,第二维是用来记录这层上的所有数
toLeft[][]是用来记录区间[tree[rt].l , l - 1]有多少个数划分到了左边
如果还是不明白具体看下面代码
void build(int l,int r,int rt,int deep)//复杂度nlogn { tree[rt].l = l;//记录每个节点的左右两个端点 tree[rt].r = r; if(l == r) return; int m = tree[rt].mid(); int midval = sorted[m];//取出每条线段上所有数的中值 int leftsame = m - l + 1;//表示在该条线段的左半部分上的数有多少和midval相等的数,先假定从[l,m]的数都和midval相等 for(int i = l ; i <= r ; i ++)//这个循环是用来求出leftsame的具体值 { if(val[deep][i] < midval) --leftsame; } int lpos = l,rpos = m + 1;//这里是用来求出下一层val[deep+1][]的值,也就是划分的过程 for(int i = l ; i <= r ; i++)//这里注意这时处理该条线段上的所有数值 { if(i == l) toLeft[deep][i] = 0;//toLeft[][i]表示[tree[rt].l,i-1]有多少数在左部 else toLeft[deep][i] = toLeft[deep][i-1];//这里相当于对toLeft[][]初始化 if(val[deep][i] < midval)//如果小于midval 对toLeft[deep][i] ++ { ++toLeft[deep][i]; val[deep+1][lpos++] = val[deep][i];//把该值放到下一层的左边 } else if(val[deep][i] > midval) { val[deep+1][rpos++] = val[deep][i]; } else//判断和midval相等的数是放在左部还是右部 { if(leftsame >= 0)//这里表示只能放置leftsame个数,多余的都要放到右子树上去 { --leftsame; ++toLeft[deep][i]; val[deep+1][lpos++] = val[deep][i]; } else//放到下层的右子树上 { val[deep+1][rpos++] = val[deep][i]; } } } build(l,m,rt<<1,deep+1); build(m+1,r,rt<<1|1,deep+1); }查询
设定当前区间在线段[s,t]上,这时有区间[s,l-1]有toLeft[][l-1]个数进入下一层的左子树,区间[s,r]有toLeft[][r]个数进入下层的左子树,这时我们能够求出在[l,r]区间有,sum = toLeft[][r] - toLeft[][l-1]个数进入下一层的左子树,那么如果sum>=k则递归到左子树查询,否则递归到右子树。到这里应该都很容易理解。
难点就是现在知道了应该递归到左右子树,那么递归的区间呢?
首先,递归到左子树,那么现在这条线段上的数全部都是上一条线段的数应该进入左子树的,因此,这条线段的左边是上个线段[s,l-1]区间里的toLeft[][l-1]个数,紧接着的就是sum个数(toLeft[][r] - toLeft[][l-1]),所以我们能够得到新的查询区间应该是[s+toLeft[][l-1],s+sum-1],这里-1是为了处理边界问题,值得大家认真思索。
同理,递归到右子树,对于现在这条线段上的数全部都是上一条线段的数应该进入右子树的,这条线段的左边是上条线段[s,l-1]区间里的 lsum = l - 1 - toLeft[][l-1] + 1个数,紧接着就是rsum = r - st - sum个数,所以查询区间应该是[mid + 1 + lsum, mid + 1 + rsum]。
下面给出query代码
int query(int l,int r,int k,int rt,int deep) { if(l == r) return val[deep][l]; //下面就是要确认新的查找区间 int s;//表示[l,r]里在左边的数的个数 int ss;//表示[tree[rt].l,l-1]里在左边的数的个数 if(l == tree[rt].l) { s = toLeft[deep][r]; ss = 0; } else { ss = toLeft[deep][l-1]; s = toLeft[deep][r] - ss; } //注意这里的在左边的数都是和sorted[m]相比的,由此可以得到如果s>=k就去左子树找,相反则去右子树 if(s >= k) { /*进入左子树,该条线段的左边有ss个数及从是从上面[tree[rt].l,l-1]该进入左子树的数继承而来 *接着还应该有toLeft[deep][r] - toLeft[deep][l-1]个数即s个数 *所以可以确定新的查找区间应该是[tree[rt].l+ss,newl+s - 1] */ int newl = tree[rt].l + ss; int newr = newl + s - 1;//这里减1是为了处理边界问题 return query(newl,newr,k,rt<<1,deep+1); } else { /* *进入右子树,该条线段的左边应该是上条线段[tree[rt].l,l-1]应该进入右子树的数,即bb = l - tree[rt].l - ss个数 *接着还应该有上条线段[l,r]应该进入右子树的数,即b = r - l + 1 - s个数 *所以可以确定新的查询区间应该是[mid + 1 + bb,mid + 1 + bb + b - 1],这里的-1同一是为了处理边界问题 */ int m = tree[rt].mid(); int b = r - l + 1 - s;//表示[l,r]在右边的数的个数 int bb = (l - 1) - tree[rt].l + 1 - ss;//表示[tree[rt].l,l-1]在右边的数的个数 int newl = m + 1 + bb; int newr = m + b + bb;//m + r - l + 1 - toLeft[deep][r] + ss - l - tree[rt].l - ss = m+r- return query(newl,newr,k-s,rt<<1|1,deep+1); } }
最后给出完整的C++源码
const int N = 100005; struct Node { int l,r; int mid() { return (l+r)>>1; } } tree[N<<2]; int sorted[N]; int val[20][N],toLeft[20][N]; void build(int l,int r,int rt,int deep) { tree[rt].l = l; tree[rt].r = r; if(l == r) return; int m = tree[rt].mid(); int midval = sorted[m]; int leftsame = m - l + 1;//表示在左子树上有多少和midval相等的数 for(int i = l ; i <= r ; i ++) { if(val[deep][i] < midval) --leftsame; } int lpos = l,rpos = m + 1; for(int i = l ; i <= r ; i++) { if(i == l) toLeft[deep][i] = 0;//toLeft[][i]表示[tree[rt].l,i-1]有多少数在左部 else toLeft[deep][i] = toLeft[deep][i-1];//这里相当于对toLeft[][]初始化 if(val[deep][i] < midval) { ++toLeft[deep][i]; val[deep+1][lpos++] = val[deep][i]; } else if(val[deep][i] > midval) { val[deep+1][rpos++] = val[deep][i]; } else//判断和midval相等的数是放在左部还是右部 { if(leftsame >= 0) { --leftsame; ++toLeft[deep][i]; val[deep+1][lpos++] = val[deep][i]; } else { val[deep+1][rpos++] = val[deep][i]; } } } build(l,m,rt<<1,deep+1); build(m+1,r,rt<<1|1,deep+1); } int query(int l,int r,int k,int rt,int deep) { if(l == r) return val[deep][l]; //下面就是要确认新的查找区间 int s;//表示[l,r]里在左边的数的个数 int ss;//表示[tree[rt].l,l-1]里在左边的数的个数 if(l == tree[rt].l) { s = toLeft[deep][r]; ss = 0; } else { ss = toLeft[deep][l-1]; s = toLeft[deep][r] - ss; } //注意这里的在左边的数都是和sorted[m]相比的,由此可以得到如果s>=k就去左子树找,相反则去右子树 if(s >= k) { /*进入左子树,该条线段的左边有ss个数及从是从上面[tree[rt].l,l-1]该进入左子树的数继承而来 *接着还应该有toLeft[deep][r] - toLeft[deep][l-1]个数即s个数 *所以可以确定新的查找区间应该是[tree[rt].l+ss,newl+s - 1] */ int newl = tree[rt].l + ss; int newr = newl + s - 1;//这里减1是为了处理边界问题 return query(newl,newr,k,rt<<1,deep+1); } else { /* *进入右子树,该条线段的左边应该是上条线段[tree[rt].l,l-1]应该进入右子树的数,即bb = l - tree[rt].l - ss个数 *接着还应该有上条线段[l,r]应该进入右子树的数,即b = r - l + 1 - s个数 *所以可以确定新的查询区间应该是[mid + 1 + bb,mid + 1 + bb + b - 1],这里的-1同一是为了处理边界问题 */ int m = tree[rt].mid(); int b = r - l + 1 - s;//表示[l,r]在右边的数的个数 int bb = (l - 1) - tree[rt].l + 1 - ss;//表示[tree[rt].l,l-1]在右边的数的个数 int newl = m + 1 + bb; int newr = m + b + bb;//m + r - l + 1 - toLeft[deep][r] + ss - l - tree[rt].l - ss = m+r- return query(newl,newr,k-s,rt<<1|1,deep+1); } } static inline int Rint()//这段是整型数的输入外挂,可以忽略不用看 { struct X { int dig[256]; X() { for(int i = '0'; i <= '9'; ++i) dig[i] = 1; dig['-'] = 1; } }; static X fuck; int s = 1, v = 0, c; for (; !fuck.dig[c = getchar()];); if (c == '-') s = 0; else if (fuck.dig[c]) v = c ^ 48; for (; fuck.dig[c = getchar()]; v = v * 10 + (c ^ 48)); return s ? v : -v; } int main() { int n,m; while(~scanf("%d %d",&n,&m)) { for(int i = 1 ; i <= n ; i++) { scanf("%d",&val[0][i]); sorted[i] = val[0][i]; } sort(sorted+1,sorted+n+1); build(1,n,1,0); while(m--) { int a,b,c; scanf("%d %d %d",&a,&b,&c); printf("%d\n",query(a,b,c,1,0)); } } return 0; }