[HAOI2016]找相同字符

Description

给定长度分别为 \(n\), \(m\) 的两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

\(n,m \le 2\times 10^5\)

Solution

\(yyt\) 的题。考试时我还不会后缀自动机,于是只能把一个串的所有后缀插入 \(AC\) 自动机,让另一个串在上面跑:每走到一个节点就暴力跳 \(fail\) 链。记 \(cnt[x]\) 为自动机上节点 \(x\) 被多少个后缀经过,把 \(fail\) 链上所有节点的 \(cnt\) 加到答案里。

这样做的正确性在于:把 B 串的所有后缀放入 AC 自动机后,让 A 串在上面跑,到达节点 \(x\) 表示 A 的某个前缀的后缀与 B 的若干个后缀的前缀匹配上了(即匹配上了若干个子串)。由于到达的节点只代表最长的那一段匹配,更短的匹配还要沿 \(fail\) 往上跳,把 \(fail\) 链上各点的答案一并累计。下面是核心代码:

// Core loop of the AC-automaton approach: walk A over the automaton built from
// all suffixes of B and, at every position, add cnt[] over the whole fail chain.
// NOTE(review): assumes ch[][] is the completed goto function (missing edges
// already redirected through fail) and that node 0 is the root — confirm.
for (int i = 1; i <= n; ++i)
{
    p = ch[p][A[i] - 'a'];           // longest match ending at A[i]
    for (int t = p; t; t = fail[t])  // every shorter suffix of that match also matches
        ans += cnt[t];               // cnt[t]: number of B-suffixes passing through node t
}

然而这样空间会爆炸(时间也会爆炸,但在超时之前空间就已经爆了:把所有后缀插入 Trie,节点数最坏是 \(O(m^2)\) 的),于是改用后缀自动机来接受所有后缀。

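后缀自动机的做法同样依赖上面的原理:对 B 串建好 SAM 后,让 A 串在 B 的 SAM 上跑。走到状态 \(x\) 时,记当前匹配长度为 \(L\),即 A 当前前缀的长为 \(L\) 的后缀是 B 的子串,而它的每一个后缀都对应 B 中若干个匹配。这些后缀里,长度在 \((maxlen(fa(x)), L]\) 内的属于状态 \(x\),更短的属于 parent 树上 \(x\) 的祖先。因此这一步的贡献为:从 \(x\) 的父亲到根的路径上所有状态 \(y\) 的「接受串种类数 \(maxlen(y)-maxlen(fa(y))\)」×「接受串出现次数(endpos 集合大小)」之和,再加上 \((L-maxlen(fa(x)))\times|endpos(x)|\)。后一项要单独计算,因为 \(L\) 可能把 \(x\) 接受的长度区间劈开,只有长度不超过 \(L\) 的那部分才能计入。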

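前者(祖先链上的和)可以按 \(maxlen\) 从小到大的顺序在 parent 树上预处理:\(sum[x]=sum[fa(x)]+(maxlen(x)-maxlen(fa(x)))\times|endpos(x)|\)。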

再提供一种 \(SA\) 的做法:

不难发现:答案 = (A 与 B 用分隔符拼接成的串)自己匹配自己的个数 − A 串自己匹配自己的个数 − B 串自己匹配自己的个数。这是因为拼接串中一对相同的子串要么都在 A 中,要么都在 B 中,要么一个在 A、一个在 B,而只有最后一类是所求的。

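而一个串匹配自己的个数,就是枚举两个不同的后缀,对它们的 lcp 求和(以 \(i\)、\(j\) 为起点的两个后缀恰好贡献 \(lcp(i,j)\) 对相同子串;注意 \(j\) 要从 \(i+1\) 开始,否则会把每个子串和它自己配对),即:
\[ \sum_{i=1}^{n}\sum_{j=i+1}^{n}\operatorname{lcp}(i,j) \]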
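后缀排序后,排名为 \(i<j\) 的两个后缀的 lcp 等于 \(\min\{height[i+1\cdots j]\}\),所以上式就是 \(height\) 数组所有区间最小值之和(由于 \(height[1]=0\),包含它的区间贡献为 0,可以一并计入):
\[ \sum_{i=1}^n\sum_{j=i}^n\min\{height[i\cdots j]\} \]
\(height\) 求出来后这就是一个经典问题:对每个 \(height[p]\),求出以它为最小值的区间个数(找左右两侧第一个更小值的位置,可用单调栈,或像下面代码那样按值从大到小从链表中删除),乘上 \(height[p]\) 累加即可。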

所以在 A 串和 B 串之间放一个不在字符集中的分隔符(比如 $),分别对拼接串、A 串、B 串做后缀排序,算出上面那个式子,再按上式相减即可。

Code(SAM)

#include <cctype>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <utility>

typedef long long LL;
typedef unsigned long long uLL;

// Small contest helpers. The macro argument of SZ/ALL is parenthesized so that
// expression arguments such as SZ(*p) expand correctly.
#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()
#define MP(x, y) std::make_pair(x, y)
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
#define GO cerr << "GO" << endl;

using namespace std;

// Debug helper: copies the contents of /proc/self/status (Linux) to stderr,
// e.g. to check the memory usage of the process. If the file cannot be opened,
// only an empty line is printed.
inline void proc_status()
{
    std::ifstream t("/proc/self/status");
    // istreambuf_iterator needs its character type spelled out: CTAD cannot
    // deduce the default-constructed end iterator.
    std::cerr << std::string(std::istreambuf_iterator<char>(t), std::istreambuf_iterator<char>()) << std::endl;
}

// Reads one integer of type T from stdin (call as read<int>(), read<LL>(), ...).
// Characters before the first digit are skipped; a '-' anywhere in that skipped
// prefix makes the result negative. The character right after the number is
// consumed. Returns 0 if EOF is reached before any digit.
template <typename T>
inline T read()
{
    T x = 0;  // accumulate in T so that wide types are not truncated
    bool negative = false;
    int c = getchar();  // int, not char, so that EOF stays distinguishable
    while (c != EOF && !isdigit(c))
    {
        if (c == '-') negative = true;
        c = getchar();
    }
    while (c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -x : x;
}

// chkmin / chkmax: if b is strictly smaller / larger than a, assign b to a and
// return true; otherwise leave a unchanged and return false.
template <typename T> inline bool chkmin(T &a, T b) { return a > b ? (a = b, true) : false; }
template <typename T> inline bool chkmax(T &a, T b) { return a < b ? (a = b, true) : false; }

const int maxN = 200000;  // upper bound on |A| and |B| (2 * 10^5)

namespace SAM
{
    int last, Ncnt, size[maxN * 2];
    LL sum[maxN * 2];

    struct Status
    {
        int len, link;
        int ch[26];
    } st[maxN * 2]; 

    void init()
    {
        last = 0;
        st[0].link = -1;
        st[0].len = 0;
    }

    void insert(char ch)
    {
        int c = ch - 'a';
        int cur = ++Ncnt;
        int p = last;
        st[cur].len = st[p].len + 1;
        while (p != -1 and !st[p].ch[c])
        {
            st[p].ch[c] = cur;
            p = st[p].link;
        }
        if (p == -1)
            st[cur].link = 0;
        else 
        {
            int q = st[p].ch[c];
            if (st[q].len == st[p].len + 1)
                st[cur].link = q;
            else 
            {
                int clone = ++Ncnt;
                st[clone] = st[q];
                st[clone].len = st[p].len + 1;
                while (p != -1 and st[p].ch[c] == q)
                {
                    st[p].ch[c] = clone;
                    p = st[p].link;
                }
                st[q].link = st[cur].link = clone;
            }
        }
        last = cur;
        size[cur] = 1;
    }

    void debug(int x)
    {
        printf("%d link is %d\n", x, st[x].link);
        for (int i = 0; i < 26; ++i)
            if (st[x].ch[i])
                printf("%d to %d %c\n", x, st[x].ch[i], i + 'a');
        puts("-----------");
    }
}
using namespace SAM;

// |A| and |B|; filled in by Init().
int n, m;
// The two input strings, stored 1-indexed (slot 0 unused) and NUL-terminated.
char A[maxN + 2], B[maxN + 2];

// Reads A and then B (whitespace-separated tokens) from stdin.
void Input()
{
    scanf("%s", A + 1);
    scanf("%s", B + 1);
}

void Init()
{
    n = strlen(A + 1), m = strlen(B + 1);
    init();
    for (register int i = 1; i <= m; ++i)
        insert(B[i]);

    static int buc[maxN * 2+ 2], rk[maxN * 2 + 2];

    for (register int i = 1; i <= Ncnt; ++i) ++buc[st[i].len];
    for (register int i = 1; i <= Ncnt; ++i) buc[i] += buc[i - 1];
    for (register int i = 1; i <= Ncnt; ++i) rk[buc[st[i].len]--] = i;


    for (register int i = Ncnt; i >= 1; --i)
    {
        int p = rk[i];
        if (p) size[st[p].link] += size[p];
    }
    for (register int i = 1; i <= Ncnt; ++i)
    {
        int p = rk[i];
        if (p) sum[p] = sum[st[p].link] + 1ll * size[p] * (st[p].len - st[st[p].link].len);
    }
}

void Solve()
{
    LL ans = 0;
    int cur = 0, L = 0;
    for (register int i = 1; i <= n; ++i)
    {
        int c = A[i] - 'a';
        while (cur != -1 and !st[cur].ch[c])
        {
            cur = st[cur].link;
            if (cur != -1)
                L = st[cur].len;
        }
        if (cur != -1)
        {
            L++;
            cur = st[cur].ch[c];
            ans += sum[st[cur].link] + (L - st[st[cur].link].len) * size[cur];
        }
        else 
            cur = 0;
    }
    printf("%lld\n", ans);
}

int main()
{
    // read both strings -> build the SAM of B -> match A against it
    Input();
    Init();
    Solve();
}

Code(SA)

#include 
#include 
#include 
#include 
#include 

typedef long long LL;
typedef unsigned long long uLL;

// Small contest helpers. The macro argument of SZ/ALL is parenthesized so that
// expression arguments such as SZ(*p) expand correctly.
#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()
#define MP(x, y) std::make_pair(x, y)
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
#define GO cerr << "GO" << endl;

using namespace std;

// Debug helper: copies the contents of /proc/self/status (Linux) to stderr,
// e.g. to check the memory usage of the process. If the file cannot be opened,
// only an empty line is printed.
inline void proc_status()
{
    std::ifstream t("/proc/self/status");
    // istreambuf_iterator needs its character type spelled out: CTAD cannot
    // deduce the default-constructed end iterator.
    std::cerr << std::string(std::istreambuf_iterator<char>(t), std::istreambuf_iterator<char>()) << std::endl;
}

// Reads one integer of type T from stdin (call as read<int>(), read<LL>(), ...).
// Characters before the first digit are skipped; a '-' anywhere in that skipped
// prefix makes the result negative. The character right after the number is
// consumed. Returns 0 if EOF is reached before any digit.
template <typename T>
inline T read()
{
    T x(0);
    bool negative = false;
    int c = getchar();  // int, not char, so that EOF stays distinguishable
    while (c != EOF && !isdigit(c))
    {
        if (c == '-') negative = true;
        c = getchar();
    }
    while (c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -x : x;
}

// chkmin / chkmax: if b is strictly smaller / larger than a, assign b to a and
// return true; otherwise leave a unchanged and return false.
template <typename T> inline bool chkmin(T &a, T b) { return a > b ? (a = b, true) : false; }
template <typename T> inline bool chkmax(T &a, T b) { return a < b ? (a = b, true) : false; }

// Array capacity for the suffix array. The concatenated string has length at most
// 2 * 10^5 * 2 + 1, and SA::Build reads tmp[sa[i] + w] (w < length) past the end
// of the string, so this must comfortably exceed twice that length.
const int maxN = 2000002;

namespace SA
{
    // ht[r] : lcp of the suffixes ranked r-1 and r (ht[1] = 0)
    // n     : length of the string passed to the last Build()
    // sa[r] : start position of the suffix of rank r;  rk[i] : rank of suffix i
    // tmp   : scratch (second-key order, then the previous round's ranks)
    // M     : number of distinct rank values (bucket count for Rsort)
    int ht[maxN + 2], n;
    int tmp[maxN + 2], sa[maxN + 2], rk[maxN + 2], M;

    // Stable counting sort: takes the positions listed in tmp[1..n] (already in
    // second-key order), sorts them by rk[], and writes the result to sa[1..n].
    void Rsort()
    {
        static int buc[maxN + 2];

        std::fill(buc, buc + 1 + M, 0);
        for (int i = 1; i <= n; ++i) ++buc[rk[i]];
        for (int i = 1; i <= M; ++i) buc[i] += buc[i - 1];
        for (int i = n; i >= 1; --i) sa[buc[rk[tmp[i]]]--] = tmp[i];
    }

    // Builds sa, rk and ht for the 1-indexed, NUL-terminated string str[1..] by
    // prefix doubling, followed by Kasai's algorithm for the heights.
    // Initial ranks are the raw character codes (M = 230), so this assumes
    // non-negative char codes (e.g. ASCII input).
    void Build(char str[])
    {
        // Clear everything the previous call wrote (indices 1..old n). The doubling
        // loop reads tmp[sa[i] + w] beyond the current end and relies on those
        // slots being 0.
        std::fill(ht + 1, ht + 1 + n, 0);
        std::fill(sa + 1, sa + 1 + n, 0);
        std::fill(rk + 1, rk + 1 + n, 0);
        std::fill(tmp + 1, tmp + 1 + n, 0);
        n = std::strlen(str + 1), M = 230;
        for (int i = 1; i <= n; ++i)
            rk[i] = str[i], tmp[i] = i;
        Rsort();
        for (int w = 1, cnt = 0; cnt < n; w <<= 1, M = cnt)
        {
            // Second key: suffixes without a partner at +w come first, then by sa order.
            cnt = 0;
            for (int i = n - w + 1; i <= n; ++i) tmp[++cnt] = i;
            for (int i = 1; i <= n; ++i) if (sa[i] > w) tmp[++cnt] = sa[i] - w;
            Rsort();
            // Move the current ranks into tmp. Only indices 1..n can differ; every
            // other slot is 0 in both arrays, so swapping the full arrays would cost
            // O(maxN) per round for nothing.
            std::swap_ranges(rk + 1, rk + n + 1, tmp + 1);
            rk[sa[1]] = cnt = 1;
            for (int i = 2; i <= n; ++i)
                rk[sa[i]] = (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w]) ? cnt : ++cnt;
        }
        // Kasai: ht[rk[i + 1]] >= ht[rk[i]] - 1, so k drops by at most one per step.
        for (int i = 1, k = 0; i <= n; ++i)
        {
            if (k) k--;
            if (rk[i] == 1)
            {
                // The smallest suffix has no predecessor; do not compare with str[0].
                k = 0;
                ht[1] = 0;
                continue;
            }
            int j = sa[rk[i] - 1];
            while (j + k <= n && i + k <= n && str[i + k] == str[j + k]) k++;
            ht[rk[i]] = k;
        }
    }
}

// Returns the sum of min(a[l..r]) over all subarrays 1 <= l <= r <= n.
// Requires 0 <= a[i] <= n, because positions are counting-sorted by value; this
// holds for height arrays.
// Positions are unlinked from a doubly linked list in decreasing order of value.
// When p is unlinked, its current neighbours lo / hi are the nearest positions not
// yet removed (all with values <= a[p]), so a[p] is the minimum of the
// (p - lo) * (hi - p) subarrays strictly between them. Each subarray is counted
// exactly once: at the equal-valued minimum that is removed last.
LL solve(int n, int a[])
{
    static int R[maxN + 2], L[maxN + 2];
    static int rk[maxN + 2], buc[maxN + 2];

    // Counting sort of positions by value (equal values: larger index first).
    std::fill(buc, buc + n + 1, 0);
    for (int i = 1; i <= n; ++i) buc[a[i]]++;
    for (int v = 1; v <= n; ++v) buc[v] += buc[v - 1];
    for (int i = 1; i <= n; ++i) rk[buc[a[i]]--] = i;

    for (int i = 1; i <= n; ++i)
    {
        L[i] = i - 1;
        R[i] = i + 1;
    }

    LL total = 0;
    for (int r = n; r >= 1; --r)
    {
        const int p = rk[r];
        const int lo = L[p], hi = R[p];
        total += static_cast<LL>(a[p]) * (hi - p) * (p - lo);
        // unlink p
        L[hi] = lo;
        R[lo] = hi;
        L[p] = R[p] = 0;
    }
    return total;
}

// Sum of lcp over all pairs of distinct suffixes of the 1-indexed string str[1..],
// i.e. the number of pairs of equal substrings starting at different positions.
LL work(char str[])
{
    SA::Build(str);
    // Build() leaves SA::n == strlen(str + 1) and fills SA::ht[1..SA::n].
    return solve(SA::n, SA::ht);
}

// The two input strings, stored 1-indexed (slot 0 unused) and NUL-terminated.
// The concatenation buffer is a static local of Solve(), so no global one is
// kept here (the previous global s3 was never used and cost ~4 MB).
char s1[maxN + 2], s2[maxN + 2];

// Reads the two strings (whitespace-separated tokens) from stdin.
void Input()
{
    scanf("%s", s1 + 1);
    scanf("%s", s2 + 1);
}

// Prints the number of (substring of s1, substring of s2) pairs that are equal,
// computed as work(s1 + '$' + s2) - work(s1) - work(s2): pairs lying entirely
// inside s1 or entirely inside s2 cancel, leaving only the cross pairs.
// Assumes '$' does not occur in the input, so no common prefix crosses it.
void Solve()
{
    static char s3[maxN + 2];

    const int len1 = strlen(s1 + 1);
    const int len2 = strlen(s2 + 1);

    // s3[1..] = s1 + '$' + s2
    int pos = 0;
    for (int i = 1; i <= len1; ++i) s3[++pos] = s1[i];
    s3[++pos] = '$';
    for (int i = 1; i <= len2; ++i) s3[++pos] = s2[i];

    LL ans = work(s3);
    ans -= work(s1);
    ans -= work(s2);
    printf("%lld\n", ans);
}

int main()
{
#ifndef ONLINE_JUDGE
    // Local runs use files; presumably the judge defines ONLINE_JUDGE and uses stdin/stdout.
    freopen("P3181.in", "r", stdin);
    freopen("P3181.out", "w", stdout);
#endif
    Input();
    Solve();
}

转载于:https://www.cnblogs.com/cnyali-Tea/p/11478131.html

你可能感兴趣的:([HAOI2016]找相同字符)