Common Substrings POJ - 3415(后缀自动机)

https://vjudge.net/problem/POJ-3415
其实很早以前这道题就过了,但不过因为后缀数组的方法自己也不是很懂就没有写。
今天我学了一下后缀自动机,我们利用后缀自动机解决。
首先对A串建立后缀自动机,然后对每一个非克隆结点进行标记, 然后建一颗fail树,用一个dfs进行计数,这个步骤就是其他博客上写的拓扑排序。

  • 首先对于一个结点,他所包含的子串个数为 l e n [ i ] − l e n k [ l i n k [ i ] ] len[i]-lenk[link[i]] len[i]lenk[link[i]],那么包含长度大于等于k,那么就是 l e n [ i ] − m a x ( k , l e n [ l i n k [ i ] ] + 1 ) + 1 len[i]-max(k,len[link[i]]+1)+1 len[i]max(k,len[link[i]]+1)+1上述的结论以及推导是可以由后缀自动机的定义得到的。

  • 然后我们用B串到自动机上进行匹配,并且记录匹配长度,这里可能有同学会犯错,我就犯错了,我认为匹配的状态的len数组和匹配的长度是一样的,因此就这样标记了,然后会多算。我们继续,如果当匹配长度大于等于k时,就可以 a n s + = ( l e n [ i ] − m a x ( k , l e n [ l i n k [ i ] ] + 1 ) + 1 ) ∗ d 1 [ i ] ans+=(len[i]-max(k,len[link[i]]+1)+1)*d1[i] ans+=(len[i]max(k,len[link[i]]+1)+1)d1[i]这里的d1数组就是上面的dfs计数得到的。然后如果是 l e n [ l i n k [ i ] ] len[link[i]] len[link[i]]大于等于k那么就要对d2进行计数了,因为你不管你是从那一个儿子上来的,只要父亲满足上面那个条件,总会找到一个满足的串。

  • 最后就仿造上面的步骤,按着刚刚的那个方法在fail树上进行计数就可以了。
    整体来看SAM的效率还是挺高的,但不过确实不好理解。

//#include "bits/stdc++.h"
#include 
#include 
#include 
#include 
#include 

using namespace std;
//inline int read() {
     
//    int x = 0;
//    bool f = 1;
//    char c = getchar();
//    for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
//    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
//    if (f) return x;
//    return 0 - x;
//}
typedef long long ll;
const int maxn = 110000 + 10;
int getc(char c) {
     
    if (c >= 'a' && c <= 'z') return c - 'a';
    return c - 'A' + 26;
}
int len[maxn << 1], link[maxn << 1];
int ch[maxn << 1][60], sz, last;
int d1[maxn << 1], d2[maxn << 1];
void init() {
     
    len[0] = 0;
    link[0] = -1;
    sz = 1;
    last = 0;
    memset(d1, 0, sizeof(d1));
    memset(d2, 0, sizeof(d2));
    memset(ch[0], 0, sizeof(ch[0]));
}

void extend(int c) {
     
    int cur = sz++, p = last;
    len[cur] = len[last] + 1;
    d1[cur] = 1;
    memset(ch[cur], 0, sizeof(ch[cur]));
    while (p != -1 && !ch[p][c]) {
     
        ch[p][c] = cur;
        p = link[p];
    }
    if (p == -1) {
     
        link[cur] = 0;
    } else {
     
        int q = ch[p][c];
        if (len[p] + 1 == len[q]) {
     
            link[cur] = q;
        } else {
     
            int clone = sz++;
            len[clone] = len[p] + 1;
            memcpy(ch[clone], ch[q], sizeof(ch[q]));
            link[clone] = link[q];
            while (p != -1 && ch[p][c] == q) {
     
                ch[p][c] = clone;
                p = link[p];
            }
            link[q] = link[cur] = clone;
        }
    }
    last = cur;
}
struct node {
     
    int v, next;
} ed[maxn << 1];
int head[maxn << 1], cnt = 0;
void add_edge(int u, int v) {
     
    ++cnt;
    ed[cnt].v = v;
    ed[cnt].next = head[u];
    head[u] = cnt;
}
void dfs(int u) {
     
    for (int i = head[u]; i; i = ed[i].next) {
     
        int v = ed[i].v;
        dfs(v);
        d1[u] += d1[v];
    }
}

int k;
char a[maxn], b[maxn];
ll ans;

void dfs2(int u) {
     
    for (int i = head[u]; i; i = ed[i].next) {
     
        int v = ed[i].v;
        dfs2(v);
        if (len[u] >= k) d2[u] += d2[v];
    }
    ans += 1ll * d1[u] * d2[u] * (len[u] - max(k, len[link[u]] + 1) + 1);
}

int main() {
     
    while (~scanf("%d", &k) && k) {
     
        init();
        memset(head, 0, sizeof(head));
        cnt = 0;
        scanf("%s%s", a, b);
        int lena = strlen(a);
        int lenb = strlen(b);
        for (int i = 0; i < lena; i++) {
     
            extend(getc(a[i]));
        }
        for (int i = 1; i < sz; i++) {
     
            add_edge(link[i], i);
        }
        dfs(0);
        ans = 0;
        int p = 0, nowlen = 0;
        for (int i = 0; i < lenb; i++) {
     
            int id = getc(b[i]);
            if (ch[p][id]) p = ch[p][id], nowlen++;
            else {
     
                while (p != -1 && !ch[p][id]) p = link[p];
                if (p == -1) nowlen = 0, p = 0;
                else nowlen = len[p] + 1, p = ch[p][id];
            }
            if (nowlen >= k) {
     
                ans += 1ll * (nowlen - max(k, len[link[p]] + 1) + 1) * d1[p];
                if (len[link[p]] >= k) d2[link[p]]++;
            }
        }
        dfs2(0);
        cout << ans << endl;
    }
    return 0;
}

你可能感兴趣的:(字符串)