学习笔记 后缀平衡树简要小结(附例题)

定义

后缀平衡树,简单的说就是动态的维护后缀数组,能做到在 O(logn) 插入, O(1) 查询 rank O(logn) 查询 SA 。当然由于后缀平衡树是支持对后缀的操作,所以要求插入操作只能在字符串开头插入字符(相当于插入一个后缀)。

离线构造

根据定义,后缀平衡树就是把后缀数组构成一棵平衡树,所以只需先构出后缀数组再构后缀平衡树。

在线构造

由于后缀平衡树只能支持在开头增加字符,所以我们就只讨论这种情况。
方案一:
现在我们需要一种能比较两个后缀大小的方法,最简单的就是二分+Hash, O(logn) 的实现这个操作。加上在平衡树上插入的复杂度,总的插入的复杂度就是 O(log2n)

方案二:
我们考虑另外一种比较方法,由于我们每次只增加一个字符,也就是说如果我们把第一个字符删掉,剩下的字符串在之前已经插入过后缀平衡树中,我们只需要线比较一下两个字符串的第一个字符,后面字符串的比较直接调用之前处理好的信息就可以了。

那么现在的问题就变成了怎么快速的比较后缀平衡树中两个后缀的大小。我们考虑对每个节点对应一个区间 (l,r) ,令节点 i tagi=l+r2 ,它的左子树对应的区间是 (l,tagi) ,右子树对应的区间是 (tagi,r) ,容易发现我们比较两个后缀大小时只用比较它对应节点的 tag 值就好了。而且平衡树的深度是 O(logn) 级别的,所以说一个节点对应的区间不会很大,如果用整数类型表达的话基本 longlong 类型就可以表示出来。

加入一个点后由于要维护对应的区间和 tag 值,普通的平衡树就维护不了了,所以要用到一种更高级的平衡树——重量平衡树(其实就是用复杂度证明的暴力),如替罪羊树,treap等。

应用

询问rank_i值

O(longn) 的时间查看对应节点在后缀平衡树中有多少个比后缀 i 小的节点即可。

询问 SAi

O(logn) 的时间查看排在位置 i 的节点对应的后缀。

维护height数组

当我们加入一个后缀时,需要重新维护的 height 值只用两个。一个是插入后缀对应的值,一个是插入后缀在后缀数组中下一个位置对应的值。我们可以先在平衡树中找到跟维护相关的后缀是哪些,提取出来用二分+Hash求 lcp 作为新的 height 值。

删除第一个字符,即删除一个后缀

现在后缀平衡树中找到这个后缀对应得节点,假设我们用的重量平衡树是 treap ,我们只需把删除节点的儿子节点想可持久化 treap 一样合并起来,对 height 值重新维护一下就可以了。

具体实现可以在例题中看到。

例题:JZOJ4384. 【GDOI2016模拟3.14】hashit

题目大意

你有一个字符串 S ,一开始为空串,要求支持两种操作:
1. 在 S 后面加入字母 c
2. 删除 S 最后一个字母。
现在有 Q 个操作,问每次操作后 S 有多少个两两不同的连续子串。

Q105

解题思路

当然这题比较简单的是离线构Trie跑SAM,但是更好的掌握后缀平衡树,这里就提供一种用到后缀平衡树的方法。

那么假如用后缀平衡树的话,这题就变成了裸题。由于题目要求是在末尾加,为了满足后缀平衡树的要求,我们只需把整个串翻转一下就可以变成只会对开头操作。这题要求的是不同的子串的数目,是一个很经典的可以用后缀数组解决的问题,这里就不在阐述。现在只是把问题变成了动态,那么我们只需动态的维护 height 值就可以了,就是一道模板题。

突然发现我这题打的是带log的,O(1)比较的在下面。。。

log比较

//YxuanwKeith
#include 
#include 
#include 
#include 

using namespace std;
typedef unsigned long long LL;

const int MAXN = 1e5 + 5;
const int Pri = 9973;
const LL Inf = 1ll << 62;

struct SuffixBalanceTree {
    LL tag;
    int Son[2], Size, fix;
} Tr[MAXN];

int tot, Root, Lcp[MAXN];
char S[MAXN];
LL Del, Pow[MAXN], Has[MAXN];

bool Cmp(int x, int y) {
    return S[x] < S[y] || S[x] == S[y] && Tr[x - 1].tag < Tr[y - 1].tag;
}

void Clear(int x, LL l, LL r) {
    Tr[x].fix = rand();
    Tr[x].Son[0] = Tr[x].Son[1] = 0;
    Tr[x].Size = 1;
    Tr[x].tag = (l + r) >> 1;
}

void Update(int Now, LL l, LL r) {
    if (!Now) return;
    Tr[Now].tag = (l + r) >> 1;
    LL Mid = (l + r) >> 1;
    Update(Tr[Now].Son[0], l, Mid), Update(Tr[Now].Son[1], Mid + 1, r);
    Tr[Now].Size = Tr[Tr[Now].Son[0]].Size + Tr[Tr[Now].Son[1]].Size + 1;
}

void Rotate(int Now, int Pre, LL l, LL r) {
    int Side = Tr[Pre].Son[1] == Now;
    Tr[Pre].Son[Side] = Tr[Now].Son[Side ^ 1];
    Tr[Now].Son[Side ^ 1] = Pre;
    Update(Now, l, r);
}

void Insert(int &Now, LL l, LL r) {
    if (!Now) {
        Now = tot;
        Clear(Now, l, r);
        return;
    }
    Tr[Now].Size ++;
    LL Mid = (l + r) >> 1;
    if (Cmp(tot, Now)) Insert(Tr[Now].Son[0], l, Mid); else
        Insert(Tr[Now].Son[1], Mid + 1, r); 
    if (Tr[tot].fix > Tr[Now].fix) {
        Rotate(tot, Now, l, r);
        Now = tot;
    }
}

int GetRank(int Now) {
    if (Now == tot) return Tr[Tr[Now].Son[0]].Size + 1;
    if (Cmp(tot, Now)) return GetRank(Tr[Now].Son[0]); 
    return Tr[Tr[Now].Son[0]].Size + 1 + GetRank(Tr[Now].Son[1]);
}

int Search(int Now, int rk) {
    if (!Now) return Now;
    if (Tr[Tr[Now].Son[0]].Size + 1 == rk) return Now;
    if (Tr[Tr[Now].Son[0]].Size >= rk) return Search(Tr[Now].Son[0], rk); 
    return Search(Tr[Now].Son[1], rk - Tr[Tr[Now].Son[0]].Size - 1);
}

bool Same(int l1, int l2, int len) {
    int r1 = l1 + len - 1, r2 = l2 + len - 1;
    return Has[r1] - Has[l1 - 1] * Pow[len] == Has[r2] - Has[l2 - 1] * Pow[len];
}

int TreatLcp(int x, int y) {
    int l = 1, r = min(x, y), Ans = 0;
    while (l <= r) {
        int Mid = (l + r) >> 1;
        if (Same(x - Mid + 1, y - Mid + 1, Mid)) Ans = Mid, l = Mid + 1; else 
            r = Mid - 1;
    }
    Lcp[y] = Ans;
}

void Add(char c) {
    S[++ tot] = c;
    Has[tot] = Has[tot - 1] * Pri + S[tot] - 'a' + 1;
    Insert(Root, 1, Inf);
    int rk = GetRank(Root);
    int l = Search(Root, rk - 1), r = Search(Root, rk + 1);
    Del = Del - Lcp[r];
    TreatLcp(l, tot), TreatLcp(tot, r);
    Del = Del + Lcp[tot] + Lcp[r];
}

int Merge(int a, int b) {
    if (!a) return b;
    if (!b) return a;
    if (Tr[a].fix < Tr[b].fix) {
        Tr[b].Son[0] = Merge(a, Tr[b].Son[0]);
        return b;
    }
    Tr[a].Son[1] = Merge(Tr[a].Son[1], b);
    return a;
}

void Out(int &Now, LL l, LL r) {
    if (tot == Now) {
        Now = Merge(Tr[Now].Son[0], Tr[Now].Son[1]);
        Update(Now, l, r);
    } else {
        Tr[Now].Size --;
        LL Mid = (l + r) >> 1;
        if (Cmp(tot, Now)) Out(Tr[Now].Son[0], l, Mid); else 
            Out(Tr[Now].Son[1], Mid + 1, r);
    }
}

void Delete() {
    int rk = GetRank(Root);
    int l = Search(Root, rk - 1), r = Search(Root, rk + 1);
    Del = Del - Lcp[tot] - Lcp[r];
    TreatLcp(l, r);
    Del = Del + Lcp[r];
    Out(Root, 1, Inf);
    tot --;
}

int main() {
    scanf("%s", S + 1);
    int len = strlen(S + 1);
    Pow[0] = 1;
    for (int i = 1; i <= len; i ++) Pow[i] = Pow[i - 1] * Pri;
    for (int i = 1; i <= len; i ++) {
        if (S[i] == '-') Delete(); else 
            Add(S[i]);
        printf("%lld\n", 1ll * tot * (tot + 1) / 2 - Del);
    }
}

O(1)比较

#include 
#include 
#include 
#include 

using namespace std;

const int maxn=100005,mo=5371297,pri=832189,M[2]={67,71};

typedef long long LL;

typedef unsigned long long ULL;

const LL Inf=(LL)1<<60;

int n,root,fa[maxn],son[maxn][2],fix[maxn],len,pre[maxn],nxt[maxn],height[maxn];

LL l[maxn],r[maxn],rank[maxn],ans;

ULL Hash[2][maxn],Power[2][maxn];

char s[maxn],c[maxn];

void rebuild(int x,LL L,LL R)
{
    if (!x) return;
    l[x]=L; r[x]=R; rank[x]=l[x]+r[x];
    rebuild(son[x][0],l[x],rank[x]>>1);
    rebuild(son[x][1],rank[x]>>1,r[x]);
}

void Rotate(int x,int t,LL l,LL r)
{
    int y=fa[x];
    if (y==root) root=x;else
    {
        if (son[fa[y]][0]==y) son[fa[y]][0]=x;else son[fa[y]][1]=x;
    }
    fa[x]=fa[y];
    son[y][t]=son[x][t^1]; fa[son[y][t]]=y; son[x][t^1]=y; fa[y]=x;
    rebuild(x,l,r);
}

bool cmp(int x,int y)
{
    return s[x]1]1];
}

int ran(int x)
{
    return (LL)(s[x]+x)*pri%mo;
}

bool Same(int x,int y,int l)
{
    if (xreturn 0;
    for (int j=0;j<2;j++)
        if (Hash[j][x]-Hash[j][x-l]*Power[j][l]!=Hash[j][y]-Hash[j][y-l]*Power[j][l]) return 0;
    return 1;
}

int lcp(int x,int y)
{
    if (x>y) x^=y^=x^=y;
    int l,r,mid;
    for (l=1,r=x+1,mid=l+r>>1;l>1)
        if (Same(x,y,mid)) l=mid+1;else r=mid;
    return l-1;
}

void update(int x)
{
    ans-=nxt[x]-height[nxt[x]];
    height[x]=lcp(x,pre[x]);
    height[nxt[x]]=lcp(nxt[x],x);
    ans+=x-height[x]+nxt[x]-height[nxt[x]];
}

void insert(int x,int i,LL l,LL r)
{
    LL mid=l+r>>1;
    if (cmp(x,i))
    {
        if (son[i][0]) insert(x,son[i][0],l,mid);
        else
        {
            son[i][0]=x; fa[x]=i; rebuild(x,l,mid);
            pre[x]=pre[i]; nxt[pre[x]]=x;
            nxt[x]=i; pre[i]=x;
        }
        if (fix[i]>fix[son[i][0]]) Rotate(son[i][0],0,l,r);
    }else
    {
        if (son[i][1]) insert(x,son[i][1],mid,r);
        else
        {
            son[i][1]=x; fa[x]=i; rebuild(x,mid,r);
            nxt[x]=nxt[i]; pre[nxt[x]]=x;
            pre[x]=i; nxt[i]=x;
        }
        if (fix[i]>fix[son[i][1]]) Rotate(son[i][1],1,l,r);
    }
}

void Delete(int x)
{
    while (son[x][0]>0 || son[x][1]>0)
        if (!son[x][0] || son[x][1]>0 && fix[son[x][0]]>fix[son[x][1]]) Rotate(son[x][1],1,l[x],r[x]);
        else Rotate(son[x][0],0,l[x],r[x]);
    if (root==x) root=0;else
    {
        ans-=x-height[x]+nxt[x]-height[nxt[x]];
        if (x==son[fa[x]][0]) son[fa[x]][0]=0;else son[fa[x]][1]=0;
        pre[nxt[x]]=pre[x]; nxt[pre[x]]=nxt[x];
        height[nxt[x]]=min(height[nxt[x]],height[x]);
        ans+=nxt[x]-height[nxt[x]];
    }
}

int main()
{
    scanf("%s",c+1);
    n=strlen(c+1);
    Power[0][0]=Power[1][0]=1;
    for (int i=1;i<=n;i++)
        for (int j=0;j<2;j++) Power[j][i]=Power[j][i-1]*M[j];
    for (int i=1;i<=n;i++)
    {
        if (c[i]=='-') Delete(len--);else
        {
            s[++len]=c[i];
            fix[len]=ran(len);
            fa[len]=son[len][0]=son[len][1]=0;
            for (int j=0;j<2;j++) Hash[j][len]=Hash[j][len-1]*M[j]+s[len]-'a';
            if (len==1)
            {
                root=1;
                l[1]=0; rank[1]=r[1]=Inf;
                ans=1;
                printf("1\n");
                continue;
            }
            insert(len,root,0,Inf);
            update(len);
        }
        printf("%lld\n",ans);
    }
    return 0;
}

你可能感兴趣的:(算法-String)