可持久化线段树笔记

可持久化数据结构主要解决有查询历史版本或者返回历史版本的操作。
可持久化线段树就是一种可持久化数据结构。
最简单的可持久化线段树的方法是对于不同的时间,都建一棵新线段树,当前时刻的线段树可以由前一时刻复制来,然后在当前时刻的线段树上面进行修改。
然而这样的可持久化线段树非常耗费内存,有一种感性的想法,就是如果当前时刻的线段树和前一时刻的线段树有共用的节点,这样这些节点就可以不用复制,直接把要修改的节点额外接在前一时刻的线段树上,这样就可以耗费比较小的内存占用而存储每个时间的线段树了。
发现要修改的节点从根开始,一直到树的内部。约为logn个。
此时这棵线段树不具备二叉树的节点编号性质,需要额外存储左右儿子的指针。
为了节约内存,一般不再保存l,r,len这些信息,而是直接算出来。
并且对于每个时间的线段树都要保存一个根。
然后再来看以前的线段树的操作:
1、单点修改
从root[now]这个根进去,然后就和普通线段树差不多了。
2.单点查询
从root[now]这个根进去,然后就和普通线段树差不多了。
3.区间修改
从root[now]这个根进去,往下走需要pushdown,找到区间打标记,往上递归的时候需要pushup
这两个操作记得新开节点。
4.区间查询
从root[now]这个根进去,往下走需要pushdown,记得新开节点。
然后新开节点在某些卡内存的题里面非常坑,可以不用pushup、pushdown解决区间修改+区间查询。
只需要在对应区间打上标记,然后不要下传。
修改往下走的时候在每个区间都要即时修改sum
查询往下走的时候在每个经过的区间都要加上经过的区间的标记。

void add(int &now,int l,int r,int val,int L,int R)
{
    rec[++cnt]=rec[now];
    now=cnt;
    rec[now].sum+=1ll*val*(r-l+1);
    if(l==L&&r==R)
    {
        rec[now].lazy+=val;
        return;
    }
    if(r<=(L+R)/2)add(rec[now].lc,l,r,val,L,(L+R)/2);
    else if(l>(L+R)/2)add(rec[now].rc,l,r,val,(L+R)/2+1,R);
    else
    {
        add(rec[now].lc,l,(L+R)/2,val,L,(L+R)/2);
        add(rec[now].rc,(L+R)/2+1,r,val,(L+R)/2+1,R);
    }
}
ll query(int now,int l,int r,int L,int R)
{
    if(l==L&&r==R)return rec[now].sum;
    ll tmp=1ll*(r-l+1)*rec[now].lazy;
    if(r<=(L+R)/2)return tmp+query(rec[now].lc,l,r,L,(L+R)/2);
    if(l>(L+R)/2)return tmp+query(rec[now].rc,l,r,(L+R)/2+1,R);
    return tmp+query(rec[now].lc,l,(L+R)/2,L,(L+R)/2)+query(rec[now].rc,(L+R)/2+1,r,(L+R)/2+1,R);
}

然后要是返回以前的版本,就用root[pre]就可以了。
贴一个HDU 4348可持久化线段树的笔记。
区间修改+区间查询+可持久化,注意有点卡内存
时间复杂度nlogn,空间复杂度nlogn

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
struct node{
    ll sum,lazy;
    int lc,rc;
}rec[3000010];
int n,m,cnt,x,y,z,now,a[100010],rt[100010];
char op[3];
void build(int &now,int l,int r)
{
    rec[++cnt]=rec[now];
    now=cnt;
    if(l==r)
    {
        rec[now].sum=a[l];
        return;
    }
    build(rec[now].lc,l,(l+r)/2);
    build(rec[now].rc,(l+r)/2+1,r);
    rec[now].sum=rec[rec[now].lc].sum+rec[rec[now].rc].sum;
}
void add(int &now,int l,int r,int val,int L,int R)
{
    rec[++cnt]=rec[now];
    now=cnt;
    rec[now].sum+=1ll*val*(r-l+1);
    if(l==L&&r==R)
    {
        rec[now].lazy+=val;
        return;
    }
    if(r<=(L+R)/2)add(rec[now].lc,l,r,val,L,(L+R)/2);
    else if(l>(L+R)/2)add(rec[now].rc,l,r,val,(L+R)/2+1,R);
    else
    {
        add(rec[now].lc,l,(L+R)/2,val,L,(L+R)/2);
        add(rec[now].rc,(L+R)/2+1,r,val,(L+R)/2+1,R);
    }
}
ll query(int now,int l,int r,int L,int R)
{
    if(l==L&&r==R)return rec[now].sum;
    ll tmp=1ll*(r-l+1)*rec[now].lazy;
    if(r<=(L+R)/2)return tmp+query(rec[now].lc,l,r,L,(L+R)/2);
    if(l>(L+R)/2)return tmp+query(rec[now].rc,l,r,(L+R)/2+1,R);
    return tmp+query(rec[now].lc,l,(L+R)/2,L,(L+R)/2)+query(rec[now].rc,(L+R)/2+1,r,(L+R)/2+1,R);
}
int main()
{
    while(~scanf("%d%d",&n,&m))
    {
        now=cnt=0;
        for(int i=1;i<=n;++i)
            scanf("%d",&a[i]);
        build(rt[0],1,n);
        for(int i=1;i<=m;++i)
        {
            scanf("%s",op);
            if(op[0]=='C')
            {
                scanf("%d%d%d",&x,&y,&z);
                ++now;
                add(rt[now]=rt[now-1],x,y,z,1,n);
            }
            else if(op[0]=='Q')
            {
                scanf("%d%d",&x,&y);
                printf("%I64d\n",query(rt[now],x,y,1,n));
            }
            else if(op[0]=='H')
            {
                scanf("%d%d%d",&x,&y,&z);
                printf("%I64d\n",query(rt[z],x,y,1,n));
            }
            else scanf("%d",&now);
        }
    }
}

然后可持久化线段树可以解决一个经典问题:动态区间第k大。
先来看一个不带修改的动态区间第k大。
按顺序把区间内的每一个数都插到可持久化线段树里面。每插入一次时间版本就要改变。
那么,对于根为root[r]和root[l-1]的两棵线段树,一棵保存的是插入了l-1个数的线段树,一个保存的是插入了r个数的线段树,对于这两棵线段树,如果两棵线段树的左子树的size之差为t,就说明在[l,r]之间有t个数小于等于mid。mid是值域的一半。利用这个性质,区间第k大就可以做了。
还有和这个类似的一些变式:
1、查询[l,r]之间小于等于x的数的个数:线段树里面二分查找x,走向右子树的话就要加上两棵左子树size的差,走到叶子就要加上两棵树这个叶子sz的差。
2、查询树上u到v的简单路径上的点权第k大:对于每个点,都维护一个这个点到根的可持久化线段树。类比动态区间第k大里面,对于每个数,都维护一个这个数到第一个数的可持久化线段树。设u和v的lca为t,那么根为root[u],root[v],root[t],root[fa[t]]的这四棵线段树,设sum=size[lc[root[u]]]+size[lc[root[v]]]-size[lc[root[t]]]-size[lc[root[fa[t]]]],那么sum代表什么呢?sum就代表u到v的简单路径上有sum个点小于等于mid。
贴个代码SPOJ COT
查询树上u到v的简单路径上的点权第k大
时间复杂度logn,空间复杂度logn

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
struct node{
    int v,next;
}edge[200010];
int lc[2000010],rc[2000010],cnt[2000010],sorted[100010],a[100010],rt[100010],fa[100010],son[100010],sz[100010],top[100010],dep[100010],head[100010],n,m,u,v,k,len,tot,ecnt;
void addedge(int u,int v)
{
    edge[ecnt].v=v;
    edge[ecnt].next=head[u];
    head[u]=ecnt++;
}
void insert(int &now,int pos,int l,int r)
{
    ++tot;
    lc[tot]=lc[now];
    rc[tot]=rc[now];
    cnt[tot]=cnt[now]+1;
    now=tot;
    if(l==r)return;
    if(pos<=(l+r)/2)insert(lc[now],pos,l,(l+r)/2);
    else insert(rc[now],pos,(l+r)/2+1,r);
}
void dfs1(int u,int f)
{
    sz[u]=1;
    int maxch=0;
    for(int i=head[u];~i;i=edge[i].next)
    {
        int v=edge[i].v;
        if(v!=f)
        {
            dep[v]=dep[u]+1;
            fa[v]=u;
            dfs1(v,u);
            sz[u]+=sz[v];
            if(maxch<sz[v])
            {
                maxch=sz[v];
                son[u]=v;
            }
        }
    }
}
void dfs2(int u,int t)
{
    int pos=lower_bound(sorted+1,sorted+len+1,a[u])-sorted;
    insert(rt[u]=rt[fa[u]],pos,1,len);
    top[u]=t;
    if(son[u]==0)return;
    dfs2(son[u],t);
    for(int i=head[u];~i;i=edge[i].next)
        if(edge[i].v!=son[u]&&edge[i].v!=fa[u])
            dfs2(edge[i].v,edge[i].v);
}
int lca(int u,int v)
{
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        u=fa[top[u]];
    }
    return dep[u]<dep[v]?u:v;
}
int find(int u,int v,int t,int fat,int l,int r,int k)
{
    if(l==r)return l;
    int sum=cnt[lc[u]]+cnt[lc[v]]-cnt[lc[t]]-cnt[lc[fat]];
    if(k<=sum)return find(lc[u],lc[v],lc[t],lc[fat],l,(l+r)/2,k);
    else return find(rc[u],rc[v],rc[t],rc[fat],(l+r)/2+1,r,k-sum);
}
int main()
{
    memset(head,-1,sizeof head);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)
        scanf("%d",&a[i]),sorted[i]=a[i];
    sort(sorted+1,sorted+n+1);
    len=unique(sorted+1,sorted+n+1)-sorted-1;
    for(int i=1;i<n;++i)
    {
        scanf("%d%d",&u,&v);
        addedge(u,v);
        addedge(v,u);
    }
    fa[1]=0;
    dep[1]=1;
    dfs1(1,-1);
    dfs2(1,1);
    for(int i=1;i<=m;++i)
    {
        scanf("%d%d%d",&u,&v,&k);
        int t=lca(u,v);
        printf("%d\n",sorted[find(rt[u],rt[v],rt[t],rt[fa[t]],1,len,k)]);
    }
}

然后再来看可以单点修改的动态区间第k大。
此时光用可持久化线段树不行了,因为修改一个节点,以后的很多历史版本都会被修改。
考虑以前我们做动态区间第k大的时候,可持久化线段树每次都是从前一个时间修改而来的,非常像求了一个前缀和。因此,当有一棵线段树被修改的时候,还有大约n个从那棵线段树累加上来的线段树需要修改。
因此我们想到使用树状数组。
就像以前树状数组求前缀和那样,这次求前缀和的对象换成了可持久化线段树,这样时间多耗费了一个logn的时间。但是后来的修改也只需要logn。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int lc[10000010],rc[10000010],cnt[10000010],data[10010][4],rt[50010],a[50010],lseg[510],rseg[510],sorted[60010],n,m,tot,len;
char op[3];
int lowbit(int x)
{
    return x&-x;
}
int newnode(int c,int l,int r)
{
    ++tot;
    lc[tot]=l;
    rc[tot]=r;
    cnt[tot]=c;
    return tot;
}
void build(int &x,int l,int r)
{
    x=newnode(0,0,0);
    if(l==r)return;
    build(lc[x],l,(l+r)/2);
    build(rc[x],(l+r)/2+1,r);
}
int find(int l,int r,int k)
{
    if(l==r)return l;
    int sum=0;
    for(int i=1;i<=rseg[0];++i)
        sum+=cnt[lc[rseg[i]]];
    for(int i=1;i<=lseg[0];++i)
        sum-=cnt[lc[lseg[i]]];
    if(k<=sum)
    {
        for(int i=1;i<=rseg[0];++i)
            rseg[i]=lc[rseg[i]];
        for(int i=1;i<=lseg[0];++i)
            lseg[i]=lc[lseg[i]];
        return find(l,(l+r)/2,k);
    }
    else
    {
        for(int i=1;i<=rseg[0];++i)
            rseg[i]=rc[rseg[i]];
        for(int i=1;i<=lseg[0];++i)
            lseg[i]=rc[lseg[i]];
        return find((l+r)/2+1,r,k-sum);
    }
}
void insert(int pre,int &now,int pos,int val,int l,int r)
{
    now=newnode(cnt[pre]+val,lc[pre],rc[pre]);
    if(l==r)return;
    if(pos<=(l+r)/2)insert(lc[pre],lc[now],pos,val,l,(l+r)/2);
    else insert(rc[pre],rc[now],pos,val,(l+r)/2+1,r);
}
void modify(int now,int pos,int val)
{
    int tmp;
    while(now<=n)
    {
        insert(rt[now],tmp,pos,val,1,len);
        rt[now]=tmp;
        now+=lowbit(now);
    }
}
int query(int l,int r,int k)
{
    lseg[0]=rseg[0]=0;
    while(l)
    {
        lseg[++lseg[0]]=rt[l];
        l-=lowbit(l);
    }
    while(r)
    {
        rseg[++rseg[0]]=rt[r];
        r-=lowbit(r);
    }
    return find(1,len,k);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)
        scanf("%d",&a[i]),sorted[++sorted[0]]=a[i];
    for(int i=1;i<=m;++i)
    {
        scanf("%s",op);
        if(op[0]=='Q')
        {
            data[i][0]=0;
            scanf("%d%d%d",&data[i][1],&data[i][2],&data[i][3]);
        }
        else
        {
            data[i][0]=1;
            scanf("%d%d",&data[i][1],&data[i][2]);
            sorted[++sorted[0]]=data[i][2];
        }
    }
    sort(sorted+1,sorted+sorted[0]+1);
    len=unique(sorted+1,sorted+sorted[0]+1)-sorted-1;
    for(int i=1;i<=n;++i)
        a[i]=lower_bound(sorted+1,sorted+len+1,a[i])-sorted;
    build(rt[0],1,len);
    for(int i=1;i<=n;++i)
        modify(i,a[i],1);
    for(int i=1;i<=m;++i)
    {
        if(!data[i][0])printf("%d\n",sorted[query(data[i][1]-1,data[i][2],data[i][3])]);
        else
        {
            modify(data[i][1],a[data[i][1]],-1);
            a[data[i][1]]=lower_bound(sorted+1,sorted+len+1,data[i][2])-sorted;
            modify(data[i][1],a[data[i][1]],1);
        }
    }
}

你可能感兴趣的:(数据结构)