最近打算研究一波树套树,以下分别介绍了树状数组套主席树和线段树套平衡树的原理和简单用法。
众所周知,主席树维护的是一种类似前缀和的结构,每个节点都是包含了之前所有节点值的权值线段树,通过继承上一个节点权值线段树的部分结构以减少大量的空间和时间消耗。
因为维护的是前缀和的结构,因此主席树满足可减性,在解决如静态区间第k小等问题中只需要取区间右端的树减区间左端的树即可得到仅包含有区间内值的权值线段树,这其实就类似于求一个序列的某个区间和可以用前缀和数组,区间右端的值减区间左端的值得到。
当然,以上都是废话,会写主席树的话肯定也知道这些东西。不过我还是要写出来,是因为想展示主席树其实本身也是一种“数据结构套数据结构”的形式,把每个节点的权值线段树抽象成点,主席树的上层就是一个简单的前缀和数组,下层使用权值线段树代替了前缀和数组中的每一个位置。而带修主席树不外乎就是把这个上层结构更换了一下,换成树状数组或线段树之类的其它数据结构。
以下通过两个简单的板子题展示下如何使用树状数组套主席树(其实应该是树状数组套权值线段树
给一个含有n个数的序列,需要支持两种操作:
1、查询下标在 [l, r] 内的第 k 小的数
2、把序列中第 x 个数更改为 y
共进行m次操作, n,m≤105。
相比于静态第k小,多了一个单点修改的操作。
如果直接莽,用普通主席树写,每次修改操作从第i位到最后一位全部改一遍,那必将t的很惨。
要支持单点修改,普通主席树的上层结构:前缀和数组 显然是无法满足的。想到既要支持单点修改,又要支持区间查询的有啥数据结构?答案肯定就是树状数组啦!(为啥这题不用线段树?因为不好写没必要而且占空间太大)
把主席树的整个上层结构换成树状数组,单点修改用树状数组(或其它数据结构)的方法,查询也用树状数组(或其它数据结构)的方法,这就是带修主席树的主要思路了。
既然上层选择了树状数组,那整个树的构建方法肯定也不能和普通主席树相同了。考虑树状数组的写法,每次修改序列中一个位置的值,最多会修改树状数组中log(n)个位置的值。而对一颗权值线段树添加一个值,只需要多开一条新的最多长log(n)的链就行了(开链过程十分类似普通主席树从第 i 个权值线段树构建第 i+1 个权值线段树的过程)。那么整体而言,修改序列一个位置的值在树状数组套主席树的结构中复杂度就是O(log(n)*log(n))。
查询时类似,在上层结构树状数组中,单次查询最多要访问log(n)个点,在下层权值线段树中,最多访问log(n)个点,总体复杂度也是O(log(n)*log(n))的。如何找第k小,相信大家都做过主席树找静态数组第k小,参照那个写法写就行了。
ps:以上为了方便写的都是log(n),其实在树状数组部分log里面确实是n(即序列中数的个数),而在权值线段树部分log里面应该是数字的范围。
注意,此题还需要先离散化一下,还有树状数组套主席树的空间复杂度和时间复杂度都是O(n*log(n)*log(n))的,因此要十分注意空间是否开够。
细节可以参考代码。
#pragma GCC optimize(2)
#include
#define inf 1000000000
#define maxn 101000
using namespace std;
typedef long long ll;
int cnt, root[maxn];
struct Chair
{
int sum, ls, rs;
}tr[40100000];
int n, m, u, a[maxn];
map<int, int> uni, reuni;
struct item
{
int l, r, k;
}opt[maxn];
int update(int pre, int l, int r, int x, int f)
{
int rt=++cnt, mid=(l+r)/2;
tr[rt]=tr[pre], tr[rt].sum+=f;
if (l==r) return rt;
if (x<=mid) tr[rt].ls=update(tr[rt].ls, l, mid, x, f);
else tr[rt].rs=update(tr[rt].rs, mid+1, r, x, f);
return rt;
}
void add(int p, int x, int f)
{
for (int i=p; i<=n; i+=i&-i)
root[i]=update(root[i], 1, u, x, f);
}
int query(vector<int> rt_l, vector<int> rt_r, int l, int r, int k)
{
if (l==r) return l;
int suml=0, sumr=0, mid=(l+r)/2;
for (auto &i: rt_l)
suml+=tr[tr[i].ls].sum;
for (auto &i: rt_r)
sumr+=tr[tr[i].ls].sum;
if (sumr-suml>=k)
{
for (auto &i: rt_l) i=tr[i].ls;
for (auto &i: rt_r) i=tr[i].ls;
return query(rt_l, rt_r, l, mid, k);
}
for (auto &i: rt_l) i=tr[i].rs;
for (auto &i: rt_r) i=tr[i].rs;
return query(rt_l, rt_r, mid+1, r, k-sumr+suml);
}
int main()
{
ios::sync_with_stdio(0); cin.tie(0);
cin>>n>>m;
for (int i=1; i<=n; i++)
{
cin>>a[i];
uni[a[i]]=1;
}
for (int i=1; i<=m; i++)
{
char op;
cin>>op;
if (op=='Q')
cin>>opt[i].l>>opt[i].r>>opt[i].k;
else
{
cin>>opt[i].r>>opt[i].k;
uni[opt[i].k]=1;
}
}
for (auto &i: uni)
{
i.second=++u;
reuni[u]=i.first;
}
for (int i=1; i<=n; i++)
a[i]=uni[a[i]];
for (int i=1; i<=m; i++)
if (!opt[i].l) opt[i].k=uni[opt[i].k];
for (int i=1; i<=n; i++)
add(i, a[i], 1);
for (int i=1; i<=m; i++)
{
if (opt[i].l)
{
vector<int> rt_l, rt_r;
for (int j=opt[i].l-1; j; j-=j&-j)
rt_l.push_back(root[j]);
for (int j=opt[i].r; j; j-=j&-j)
rt_r.push_back(root[j]);
cout<<reuni[query(rt_l, rt_r, 1, u, opt[i].k)]<<"\n";
}
if (!opt[i].l)
{
add(opt[i].r, a[opt[i].r], -1);
add(opt[i].r, opt[i].k, 1);
a[opt[i].r]=opt[i].k;
}
}
return 0;
}
给一个含有n个数的序列,需要支持一下操作:
1、查询k在区间内的排名
2、查询区间内排名为k的值
3、修改某一位值上的数值
4、查询k在区间内的前驱
5、查询k在区间内的后继
共进行m次操作, n,m≤5*104。
相当于是上面那题的升级版,多了三个操作。
找区间排名为k可以直接copy上面那题。
如何查询k在区间内的排名?其实就是找有几个值小于k,可以先考虑普通的主席树怎么做。比如说现在得到了包含这个区间所有点的权值线段树,从树根开始,当前点在权值线段树中代表的空间为 [l, r] ,讨论 k 是否大于区间的中点 mid ,若小于等于mid的话,往左儿子计算。若大于mid的话往往右儿子计算的同时要加上左儿子节点的个数,递归下去就行了。
第4,5个操作,要是专门各写一个函数就又要多好几十行,想想这两个操作其实就是1,2操作的结合,比如查k的前驱,就是找到k的排名,然后查找排第k-1的数是什么,后继类似。
加上这题主要是让大家稍微巩固一下(难题我也不会了
#pragma GCC optimize(2)
#include
#define inf 1000000000
#define maxn 101000
using namespace std;
typedef long long ll;
ll cnt, root[maxn], n, u;
map<ll, ll> uni, reuni;
struct Chair
{
int sum, ls, rs;
}tr[40100000];
int update(int pre, int l, int r, int x, int f)
{
int rt=++cnt, mid=(l+r)/2;
tr[rt]=tr[pre], tr[rt].sum+=f;
if (l==r) return rt;
if (x<=mid) tr[rt].ls=update(tr[rt].ls, l, mid, x, f);
else tr[rt].rs=update(tr[rt].rs, mid+1, r, x, f);
return rt;
}
void add(int p, int x, int f)
{
for (int i=p; i<=n; i+=i&-i)
root[i]=update(root[i], 1, u, x, f);
}
int query_rankk(vector<int> rt_l, vector<int> rt_r, int l, int r, int k)
{
if (l==r) return l;
int suml=0, sumr=0, mid=(l+r)/2;
for (auto &i: rt_l) suml+=tr[tr[i].ls].sum;
for (auto &i: rt_r) sumr+=tr[tr[i].ls].sum;
if (sumr-suml>=k)
{
for (auto &i: rt_l) i=tr[i].ls;
for (auto &i: rt_r) i=tr[i].ls;
return query_rankk(rt_l, rt_r, l, mid, k);
}
for (auto &i: rt_l) i=tr[i].rs;
for (auto &i: rt_r) i=tr[i].rs;
return query_rankk(rt_l, rt_r, mid+1, r, k-sumr+suml);
}
int query_krank(vector<int> rt_l, vector<int> rt_r, int l, int r, int k)
{
int suml=0, sumr=0, mid=(l+r)/2;
if (l==r) return 0;
vector<int> nxl, nxr;
for (auto &i: rt_l) nxl.push_back(tr[i].ls);
for (auto &i: rt_r) nxr.push_back(tr[i].ls);
if (k<=mid) return query_krank(nxl, nxr, l, mid, k);
for (auto &i: rt_l) suml+=tr[tr[i].ls].sum;
for (auto &i: rt_r) sumr+=tr[tr[i].ls].sum;
nxl.clear(), nxr.clear();
for (auto &i: rt_l) nxl.push_back(tr[i].rs);
for (auto &i: rt_r) nxr.push_back(tr[i].rs);
return sumr-suml+query_krank(nxl, nxr, mid+1, r, k);
}
struct item
{
ll f, l, r, k;
}opt[maxn];
ll m, a[maxn];
int main()
{
ios::sync_with_stdio(0); cin.tie(0);
cin>>n>>m;
for (int i=1; i<=n; i++)
{
cin>>a[i];
uni[a[i]]=1;
}
for (int i=1, op, l, r, k; i<=m; i++)
{
cin>>op>>l>>r;
opt[i].f=op, opt[i].l=l, opt[i].r=r;
if (op==3) {uni[opt[i].r]=1; continue;}
cin>>k;
opt[i].k=k;
if (op==5) opt[i].k++;
if (op==1 || op==4 || op==5) uni[opt[i].k]=1;
}
for (auto &i: uni)
{
i.second=++u;
reuni[u]=i.first;
}
for (int i=1; i<=n; i++)
a[i]=uni[a[i]];
for (int i=1; i<=m; i++)
{
int op=opt[i].f;
if (op==1 || op==4 || op==5) opt[i].k=uni[opt[i].k];
else if (op==3) opt[i].r=uni[opt[i].r];
}
for (int i=1; i<=n; i++)
add(i, a[i], 1);
for (int i=1; i<=m; i++)
{
if (opt[i].f==3)
{
add(opt[i].l, a[opt[i].l], -1);
add(opt[i].l, opt[i].r, 1);
a[opt[i].l]=opt[i].r;
continue;
}
vector<int> rt_l, rt_r;
for (int j=opt[i].l-1; j; j-=j&-j)
rt_l.push_back(root[j]);
for (int j=opt[i].r; j; j-=j&-j)
rt_r.push_back(root[j]);
if (opt[i].f==1)
cout<<query_krank(rt_l, rt_r, 1, u, opt[i].k)+1<<"\n";
if (opt[i].f==2)
cout<<reuni[query_rankk(rt_l, rt_r, 1, u, opt[i].k)]<<"\n";
if (opt[i].f==4)
{
int rk=query_krank(rt_l, rt_r, 1, u, opt[i].k)+1;
if (rk==1) cout<<"-2147483647\n";
else cout<<reuni[query_rankk(rt_l, rt_r, 1, u, rk-1)]<<"\n";
}
if (opt[i].f==5)
{
int rk=query_krank(rt_l, rt_r, 1, u, opt[i].k)+1;
if (rk==opt[i].r-opt[i].l+2) cout<<"2147483647\n";
else cout<<reuni[query_rankk(rt_l, rt_r, 1, u, rk)]<<"\n";
}
}
return 0;
}
线段树套平衡树的方法也类似于树状数组套主席树,在解决不同的树套树问题时各有优劣。比如在做下面这道板子题的时候树状数组套主席树就明显更优,因为平衡树对于查询第k大需要二分而多了一个log的复杂度。
将线段树中的每个节点建一棵平衡树,空间复杂度之所以可以保证是因为每个节点的平衡树大小也就和这个节点所包括的区间一样大,因此整体空间复杂度是nlog的(这一点要优于树状数组套主席树)。
修改查询等操作不难,就是按线段树的规矩找到需要操作的区间,然后对这个节点下的平衡树进行操作即可。
题意见上面树状数组套主席树部分。
真难调(可能是我写法比较蠢吧…
而且由于部分操作复杂度是三个log,需要吸氧才能ac所有数据。
判断k在区间内的排名没啥好说的,平衡树的常规操作。查找区间第k小,本来平衡树也是可以支持这个操作的,但由于查询区间内可能包含了多颗平衡树,平衡树之间又没法像主席树那样合并,因此应该是不能直接查询到的,只能先二分答案,然后通过判断排名来check。
第四第五操作依旧可以靠第一第二操作完成,不过貌似可以直接写而不套用第二个操作(因为第二个操作复杂度较高,套用会导致这两个操作复杂度也变高),不过我懒得写了不过反正程序复杂度取决于最高复杂度的操作,吸口氧也能过,就还是直接套用完事。
#pragma GCC optimize(3)
#include
#define inf 1000000000
#define maxn 51000
#define root(i) (tr[i].ch[1])
using namespace std;
typedef long long ll;
struct Splaytr
{
int v, fa, sum, rep, ch[2];
//结点值,父亲,子树元素个数合,该结点元素个数
}tr[8000000];
int spcnt;
//结点总数,元素总数,splay树根
int iden(int x)
{//判断是否为右侧点
return tr[tr[x].fa].ch[1]==x;
}
void pushup(int x)
{
tr[x].sum=tr[tr[x].ch[0]].sum+tr[tr[x].ch[1]].sum+tr[x].rep;
}
void conn(int son, int fa, int lr)
{//连接son和fa,lr表示son是fa的哪个儿子
tr[son].fa=fa, tr[fa].ch[lr]=son;
}
void rota(int x)
{//将x上旋
int fa=tr[x].fa, gfa=tr[fa].fa;
int xr=iden(x), far=iden(fa), nx=tr[x].ch[xr^1];
conn(nx, fa, xr), conn(fa, x, (xr^1)), conn(x, gfa, far);
pushup(fa), pushup(x);
}
void splay(int x, int to)
{//将x上旋至to
int tf=tr[to].fa;
while (tr[x].fa!=tf)
{
int up=tr[x].fa;
if (tr[up].fa==tf) rota(x);
else if (iden(x)==iden(up))
rota(up), rota(x);
else rota(x), rota(x);
}
}
void crepoint(int v, int fa)
{
tr[++spcnt].v=v, tr[spcnt].fa=fa;
tr[spcnt].sum=tr[spcnt].rep=1;
}
int finpoint(int pos, int v)
{
int now=root(pos);
while (true)
{
if (tr[now].v==v)
{
splay(now, root(pos));
return now;
}
now=tr[now].ch[v>tr[now].v];
if (!now) return 0;
}
}
void push(int pos, int v)
{
if (root(pos)==0) {root(pos)=spcnt+1; crepoint(v, pos); return ;}
int now=root(pos);
while (true)
{
tr[now].sum++;
if (v==tr[now].v)
{
tr[now].rep++;
splay(now, root(pos));
return ;
}
int next=(v>tr[now].v);
if (!tr[now].ch[next])
{
crepoint(v, now);
tr[now].ch[next]=spcnt;
splay(spcnt, root(pos));
return ;
}
now=tr[now].ch[next];
}
}
void pop(int pos, int v)
{
int dele=finpoint(pos, v);
if (!dele) return ;
if (tr[dele].rep>1)
{
tr[dele].rep--, tr[dele].sum--;
return ;
}
if (!tr[dele].ch[0])
{
root(pos)=tr[dele].ch[1];
tr[root(pos)].fa=pos;
}
else
{
int dls=tr[dele].ch[0];
while (tr[dls].ch[1]) dls=tr[dls].ch[1];
splay(dls, tr[dele].ch[0]);
int drs=tr[dele].ch[1];
conn(drs, dls, 1), conn(dls, pos, 1);
pushup(dls);
}
}
int krank(int pos, int k)
{
int ans=0, now=root(pos);
while (true)
{
if (now==0) break;
if (k<=tr[now].v) now=tr[now].ch[0];
else
{
ans+=tr[tr[now].ch[0]].sum+tr[now].rep;
now=tr[now].ch[1];
}
}
return ans;
}
int n, m, u;
ll a[maxn];
map<ll, int> uni;
map<int, ll> reuni;
struct item
{
ll f, l, r, k;
}opt[maxn];
void build(int l, int r, int pos)
{
spcnt=max(spcnt, pos);
if (l==r) return ;
build(l, (l+r)/2, pos<<1), build((l+r)/2+1, r, pos<<1|1);
}
void build2(int l, int r, int pos)
{
push(pos, -inf), push(pos, inf);
if (l==r) return ;
build2(l, (l+r)/2, pos<<1), build2((l+r)/2+1, r, pos<<1|1);
}
void add(int x, int v, int f, int l=1, int r=n, int pos=1)
{
if (f==1) push(pos, v);
else pop(pos, v);
if (l==r) return ;
int mid=(l+r)/2;
if (x<=mid) add(x, v, f, l, mid, pos<<1);
else add(x, v, f, mid+1, r, pos<<1|1);
}
int query_krank(int k, int L, int R, int l=1, int r=n, int pos=1)
{
if (l>=L && r<=R)
return krank(pos, k)-1;
int mid=(l+r)/2;
if (R<=mid) return query_krank(k, L, R, l, mid, pos<<1);
else if (L>mid) return query_krank(k, L, R, mid+1, r, pos<<1|1);
else return query_krank(k, L, mid, l, mid, pos<<1)+query_krank(k, mid+1, R, mid+1, r, pos<<1|1);
}
int query_rankk(int k, int L, int R)
{
int l=1, r=u, mid;
while (l<r)
{
mid=(l+r)/2;
if (query_krank(mid, L, R)+1<=k) l=mid+1;
else r=mid;
}
if (query_krank(l, L, R)+1>k) l--;
return l;
}
int main()
{
ios::sync_with_stdio(0); cin.tie(0);
cin>>n>>m;
build(1, n, 1); build2(1, n, 1);
for (int i=1; i<=n; i++)
{
cin>>a[i];
uni[a[i]]=1;
}
for (int i=1, op, l, r, k; i<=m; i++)
{
cin>>op>>l>>r;
opt[i].f=op, opt[i].l=l, opt[i].r=r;
if (op==3) {uni[opt[i].r]=1; continue;}
cin>>k;
opt[i].k=k;
if (op==5) opt[i].k++;
if (op==1 || op==4 || op==5) uni[opt[i].k]=1;
}
for (auto &i: uni)
{
i.second=++u;
reuni[u]=i.first;
}
for (int i=1; i<=n; i++)
a[i]=uni[a[i]];
for (int i=1; i<=m; i++)
{
int op=opt[i].f;
if (op==1 || op==4 || op==5) opt[i].k=uni[opt[i].k];
else if (op==3) opt[i].r=uni[opt[i].r];
}
for (int i=1; i<=n; i++)
add(i, a[i], 1);
for (int i=1; i<=m; i++)
{
if (opt[i].f==3)
{
add(opt[i].l, a[opt[i].l], -1);
add(opt[i].l, opt[i].r, 1);
a[opt[i].l]=opt[i].r;
continue;
}
if (opt[i].f==1)
cout<<query_krank(opt[i].k, opt[i].l, opt[i].r)+1<<"\n";
if (opt[i].f==2)
cout<<reuni[query_rankk(opt[i].k, opt[i].l, opt[i].r)]<<"\n";
if (opt[i].f==4)
{
int rk=query_krank(opt[i].k, opt[i].l, opt[i].r)+1;
if (rk==1) cout<<"-2147483647\n";
else cout<<reuni[query_rankk(rk-1, opt[i].l, opt[i].r)]<<"\n";
}
if (opt[i].f==5)
{
int rk=query_krank(opt[i].k, opt[i].l, opt[i].r)+1;
if (rk==opt[i].r-opt[i].l+2) cout<<"2147483647\n";
else cout<<reuni[query_rankk(rk, opt[i].l, opt[i].r)]<<"\n";
}
}
return 0;
}