Description
给定长度分别为 \(n\), \(m\) 的两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
\(n,m \le 2\times 10^5\)
Solution
\(yyt\)的题,考试时并不会后缀自动机,于是只能对一个串把所有后缀插入\(AC\)自动机,另一个串在上面跑,每次到一个节点就暴力跳\(fail\),记\(cnt[x]\)为自动机\(x\)点被多少个后缀经过,把\(fail\)的\(cnt\)加到答案里。
这样做的正确性在于:把B串的所有后缀放入AC自动机后,让A串在上面跑,到达一个点x表示A串当前前缀的某个后缀与B串若干个后缀的前缀匹配上了(即匹配上了B的若干个子串)。由于到达的点对应的只是最长的那一段匹配,所以每次还要跳\(fail\),把\(fail\)链上各点的答案累加上,下面是核心代码:
// Core loop of the AC-automaton approach: feed A[1..n] through the automaton one character at a time.
// Per the surrounding text, cnt[x] is the number of inserted suffixes that pass through node x.
for (int i = 1; i <= n; ++i)
{
// Follow the transition for the current character (letters are mapped by subtracting 'a').
p = ch[p][A[i] - 'a'];
// Walk the fail chain from p until index 0 is reached, adding cnt of every node visited.
for (int t = p; t; t = fail[t])
ans += cnt[t];
}
然而这样空间爆炸(时间也爆炸,但在超时前空间已经爆了),于是就用后缀自动机来接受所有后缀。
后缀自动机的做法也是依赖上面的原理的:对B串建好SAM后,A串仍然在B串的SAM上跑,到达一个点x代表A的一段前缀的后缀和B的若干个子串匹配上了。我们记录当前匹配的串长L,那么需要统计x接受了多少长度小于等于L的子串,也就是:parent树上从x的父亲到根路径上所有点的「接受串种类数×接受串的出现次数(endpos集合大小)」之和,再加上(L-x父亲的maxlen)×点x接受串的出现次数(注意L可能会把x接受的那些后缀劈开,所以x这一部分要单独加)。
前者预处理即可。
再提供一种 \(SA\) 的做法:
不难发现:答案 = (A+B串自己匹配自己的个数) - (A串自己匹配自己的个数) - (B串自己匹配自己的个数)。
而一个串匹配自己的个数就是枚举两个后缀然后对它们lcp求和,即:
\[ \sum_{i=1}^{n}\sum_{j=i+1}^{n}lcp(i,j) \]
也就是后缀排序后每两个后缀之间 \(height\) 的最小值:
\[ \sum_{i=1}^{n}\sum_{j=i+1}^{n}\min\{height[i+1\cdots j]\} \]
把 \(height\) 求出来后这就是一个很经典的问题了。
所以在A串和B串间放一个分隔符(比如 `$`),后缀排序再算上面那个东西即可。
Code(SAM)
#include
#include
#include
#include
#include
// Short integer aliases used throughout the program.
typedef long long LL;
typedef unsigned long long uLL;
// Container size as int. The argument is parenthesized so that expressions
// such as SZ(*ptr) or SZ(a + b) expand correctly.
#define SZ(x) ((int)(x).size())
// Expands to the begin/end iterator pair of a container.
#define ALL(x) (x).begin(), (x).end()
// Shorthand for std::make_pair.
#define MP(x, y) std::make_pair(x, y)
// printf-style logging to stderr.
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
// Prints "GO" to stderr; the expansion already ends with a semicolon.
#define GO cerr << "GO" << endl;
using namespace std;
// Debug helper: dumps the contents of /proc/self/status to stderr.
// If the file cannot be opened, the dumped text is empty.
inline void proc_status()
{
    std::ifstream t("/proc/self/status");
    // The iterator pair slurps the whole stream; the element type must be spelled out.
    std::cerr << std::string(std::istreambuf_iterator<char>(t), std::istreambuf_iterator<char>()) << std::endl;
}
// Reads the next decimal integer from stdin and returns it as T.
// Non-digit characters before the number are skipped; if any of them is '-',
// the result is negated. The character right after the last digit is consumed.
// Returns 0 if end of input is reached before any digit.
template <typename T>
inline T read()
{
    T value = 0;   // accumulate in T so wide types are not truncated
    int sign = 1;
    int c = std::getchar();   // int, so EOF is representable and isdigit() is well defined
    while (c != EOF && !std::isdigit(c))
    {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c))
    {
        value = value * 10 + (c - '0');
        c = std::getchar();
    }
    return value * sign;
}
// Lowers a to b when b is strictly smaller; returns true iff a was changed.
template <typename T>
inline bool chkmin(T &a, T b)
{
    if (a > b)
    {
        a = b;
        return true;
    }
    return false;
}
// Raises a to b when b is strictly larger; returns true iff a was changed.
template <typename T>
inline bool chkmax(T &a, T b)
{
    if (a < b)
    {
        a = b;
        return true;
    }
    return false;
}
// Upper bound on the length of each input string.
const int maxN = (int) 2e5;

// Suffix automaton over the letters 'a'..'z'.
// State 0 is the root. No transition ever targets the root, so a stored
// transition value of 0 doubles as "no transition".
namespace SAM
{
    int last;                 // state reached by the whole text inserted so far
    int Ncnt;                 // highest state index in use
    int size[maxN * 2];       // 1 for states created by insert(), 0 for clones; Init() turns it into subtree totals
    long long sum[maxN * 2];  // filled by Init(): size * (len - link.len) accumulated along the link chain

    struct Status
    {
        int len, link;
        int ch[26];
    } st[maxN * 2];

    // Resets the automaton to a lone root so a new text can be inserted.
    // States above the root are re-initialised lazily when insert() allocates them.
    void init()
    {
        Ncnt = 0;
        last = 0;
        st[0] = Status();
        st[0].link = -1;
        size[0] = 0;
        sum[0] = 0;
    }

    // Appends one character (expected in 'a'..'z') to the text.
    void insert(char ch)
    {
        const int c = ch - 'a';
        const int cur = ++Ncnt;
        st[cur] = Status();   // drop whatever an earlier build left in this slot
        st[cur].len = st[last].len + 1;
        size[cur] = 1;

        int p = last;
        while (p != -1 && !st[p].ch[c])
        {
            st[p].ch[c] = cur;
            p = st[p].link;
        }

        if (p == -1)
        {
            st[cur].link = 0;
        }
        else
        {
            const int q = st[p].ch[c];
            if (st[q].len == st[p].len + 1)
            {
                st[cur].link = q;
            }
            else
            {
                // Split q: the clone takes over the strings of length <= st[p].len + 1.
                const int clone = ++Ncnt;
                st[clone] = st[q];
                st[clone].len = st[p].len + 1;
                size[clone] = 0;
                while (p != -1 && st[p].ch[c] == q)
                {
                    st[p].ch[c] = clone;
                    p = st[p].link;
                }
                st[q].link = st[cur].link = clone;
            }
        }
        last = cur;
    }

    // Prints the link and the outgoing transitions of state x to stdout.
    void debug(int x)
    {
        printf("%d link is %d\n", x, st[x].link);
        for (int i = 0; i < 26; ++i)
            if (st[x].ch[i])
                printf("%d to %d %c\n", x, st[x].ch[i], i + 'a');
        puts("-----------");
    }
}
using namespace SAM;

int n, m;                          // lengths of A and B, set by Init()
char A[maxN + 2], B[maxN + 2];     // input strings, stored from index 1

// Reads A and B from stdin.
// NOTE(review): "%s" is unbounded; inputs are assumed to hold at most maxN characters each.
void Input() { scanf("%s%s", A + 1, B + 1); }

// Builds the automaton of B and precomputes SAM::size and SAM::sum.
// May be called again after A/B change; all derived state is rebuilt.
void Init()
{
    n = (int) std::strlen(A + 1);
    m = (int) std::strlen(B + 1);

    init();
    for (int i = 1; i <= m; ++i)
        insert(B[i]);

    // Counting sort of states 1..Ncnt by len, so that a state's link
    // (which has a strictly smaller len) always comes earlier in rk.
    static int buc[maxN * 2 + 2], rk[maxN * 2 + 2];
    std::fill(buc, buc + Ncnt + 1, 0);
    for (int i = 1; i <= Ncnt; ++i) ++buc[st[i].len];
    for (int i = 1; i <= Ncnt; ++i) buc[i] += buc[i - 1];
    for (int i = 1; i <= Ncnt; ++i) rk[buc[st[i].len]--] = i;

    // Longest first: push counts up the link tree.
    // `size` is qualified because an unqualified use at global scope clashes
    // with std::size (C++17) when `using namespace std` is in effect.
    for (int i = Ncnt; i >= 1; --i)
    {
        const int p = rk[i];
        SAM::size[st[p].link] += SAM::size[p];
    }

    // Shortest first: accumulate size * (number of lengths in the state) along the link chain.
    for (int i = 1; i <= Ncnt; ++i)
    {
        const int p = rk[i];
        const int parent = st[p].link;
        sum[p] = sum[parent] + 1ll * SAM::size[p] * (st[p].len - st[parent].len);
    }
}

// Runs A through the automaton built by Init() and returns the accumulated answer.
static long long CountMatches()
{
    long long ans = 0;
    int cur = 0;   // current state
    int L = 0;     // length of the current match
    for (int i = 1; i <= n; ++i)
    {
        const int c = A[i] - 'a';
        // Shorten the match along the links until the character can be appended.
        while (cur != -1 && !st[cur].ch[c])
        {
            cur = st[cur].link;
            if (cur != -1)
                L = st[cur].len;
        }
        if (cur == -1)
        {
            // Not even the root has this transition: restart with an empty match.
            cur = 0;
            L = 0;
            continue;
        }
        ++L;
        cur = st[cur].ch[c];
        const int parent = st[cur].link;
        // Everything on the link chain above cur, plus the part of cur no longer than L.
        ans += sum[parent] + (long long) (L - st[parent].len) * SAM::size[cur];
    }
    return ans;
}

// Prints the answer for the current A and B (Init() must have been called).
void Solve()
{
    printf("%lld\n", CountMatches());
}
// Entry point: runs Input(), Init() and Solve() in that order and exits with status 0.
int main()
{
Input();
Init();
Solve();
return 0;
}
Code(SA)
#include
#include
#include
#include
#include
// Short integer aliases used throughout the program.
typedef long long LL;
typedef unsigned long long uLL;
// Container size as int. The argument is parenthesized so that expressions
// such as SZ(*ptr) or SZ(a + b) expand correctly.
#define SZ(x) ((int)(x).size())
// Expands to the begin/end iterator pair of a container.
#define ALL(x) (x).begin(), (x).end()
// Shorthand for std::make_pair.
#define MP(x, y) std::make_pair(x, y)
// printf-style logging to stderr.
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
// Prints "GO" to stderr; the expansion already ends with a semicolon.
#define GO cerr << "GO" << endl;
using namespace std;
// Debug helper: dumps the contents of /proc/self/status to stderr.
// If the file cannot be opened, the dumped text is empty.
inline void proc_status()
{
    std::ifstream t("/proc/self/status");
    // The iterator pair slurps the whole stream; the element type must be spelled out.
    std::cerr << std::string(std::istreambuf_iterator<char>(t), std::istreambuf_iterator<char>()) << std::endl;
}
// Reads the next decimal integer from stdin and returns it as T.
// Non-digit characters before the number are skipped; if any of them is '-',
// the result is negated. The character right after the last digit is consumed.
// Returns 0 if end of input is reached before any digit.
template <typename T>
inline T read()
{
    T value(0);
    int sign = 1;
    int c = std::getchar();   // int, so EOF is representable and isdigit() is well defined
    while (c != EOF && !std::isdigit(c))
    {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c))
    {
        value = value * 10 + (c - '0');
        c = std::getchar();
    }
    return value * sign;
}
// Lowers a to b when b is strictly smaller; returns true iff a was changed.
template <typename T>
inline bool chkmin(T &a, T b)
{
    if (a > b)
    {
        a = b;
        return true;
    }
    return false;
}
// Raises a to b when b is strictly larger; returns true iff a was changed.
template <typename T>
inline bool chkmax(T &a, T b)
{
    if (a < b)
    {
        a = b;
        return true;
    }
    return false;
}
// Capacity of the suffix-array buffers (longest supported text).
const int maxN = 2e6 + 2;

// Suffix array by prefix doubling, plus the height array.
// Everything is 1-indexed; Build() reads the text from str[1..n].
namespace SA
{
    int ht[maxN + 2];    // ht[r]: common prefix length of the suffixes ranked r-1 and r; ht[1] = 0
    int n;               // length of the text of the latest Build()
    int tmp[maxN + 2];   // second-key order during sorting, previous ranks after the swap
    int sa[maxN + 2];    // sa[r]: start position of the suffix with rank r
    int rk[maxN + 2];    // rk[i]: rank of the suffix starting at i
    int M;               // largest key value handed to Rsort()

    // Radix pass: orders the positions listed in tmp[1..n] by rk (keys in 0..M) into sa.
    // Walking tmp backwards keeps the order among equal keys.
    void Rsort()
    {
        static int buc[maxN + 2];
        std::fill(buc, buc + 1 + M, 0);
        for (int i = 1; i <= n; ++i) ++buc[rk[i]];
        for (int i = 1; i <= M; ++i) buc[i] += buc[i - 1];
        for (int i = n; i >= 1; --i) sa[buc[rk[tmp[i]]]--] = tmp[i];
    }

    // Builds sa, rk and ht for the NUL-terminated text at str + 1.
    // NOTE(review): the initial keys are the raw char values and M starts at 230,
    // so characters are assumed to lie in 1..230.
    void Build(char str[])
    {
        // Remember to clear here: wipe what the previous build (of length n) left behind.
        std::fill(ht + 1, ht + 1 + n, 0);
        std::fill(sa + 1, sa + 1 + n, 0);
        std::fill(rk + 1, rk + 1 + n, 0);
        std::fill(tmp + 1, tmp + 1 + n, 0);

        n = (int) std::strlen(str + 1);
        M = 230;
        for (int i = 1; i <= n; ++i)
        {
            rk[i] = str[i];
            tmp[i] = i;
        }
        Rsort();

        for (int w = 1, cnt = 0; cnt < n; w <<= 1, M = cnt)
        {
            // Order by the second key (rank of the suffix at i + w): suffixes whose
            // second half is empty come first, the rest follow in current sa order.
            cnt = 0;
            for (int i = n - w + 1; i <= n; ++i) tmp[++cnt] = i;
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w) tmp[++cnt] = sa[i] - w;
            Rsort();

            // Keep the previous ranks in tmp. Only positions 1..n hold data, so swap
            // just that prefix rather than the whole maxN-sized arrays.
            for (int i = 1; i <= n; ++i) std::swap(rk[i], tmp[i]);

            rk[sa[1]] = cnt = 1;
            for (int i = 2; i <= n; ++i)
            {
                const int a = sa[i], b = sa[i - 1];
                // A second half that starts past the end compares as rank 0.
                const int a2 = (a + w <= n) ? tmp[a + w] : 0;
                const int b2 = (b + w <= n) ? tmp[b + w] : 0;
                rk[a] = (tmp[a] == tmp[b] && a2 == b2) ? cnt : ++cnt;
            }
        }

        // Heights in text order, reusing k - 1 matched characters from the previous position.
        for (int i = 1, k = 0; i <= n; ++i)
        {
            if (rk[i] == 1)
            {
                // The smallest suffix has no predecessor; do not touch sa[0] / str[0].
                ht[1] = 0;
                k = 0;
                continue;
            }
            if (k) --k;
            const int j = sa[rk[i] - 1];
            while (j + k <= n && i + k <= n && str[i + k] == str[j + k]) ++k;
            ht[rk[i]] = k;
        }
    }
}
// Returns the sum of min(a[l..r]) over all ranges 1 <= l <= r <= n.
// Requires 0 <= a[i] <= n, because the values index the counting-sort buckets.
LL solve(int n, int a[])
{
    static int R[maxN + 2], L[maxN + 2];      // neighbours among the positions not yet removed
    static int rk[maxN + 2], buc[maxN + 2];   // positions ordered by value, and the buckets

    // Counting sort of the positions by value, ascending.
    std::fill(buc, buc + n + 1, 0);
    for (int i = 1; i <= n; ++i) ++buc[a[i]];
    for (int i = 1; i <= n; ++i) buc[i] += buc[i - 1];
    for (int i = 1; i <= n; ++i) rk[buc[a[i]]--] = i;

    for (int i = 1; i <= n; ++i)
    {
        L[i] = i - 1;
        R[i] = i + 1;
    }

    // Remove positions from the last sorted entry backwards. When p is handled, every
    // position strictly between L[p] and R[p] other than p has already been removed,
    // so p is charged for (R[p] - p) * (p - L[p]) ranges.
    LL ans = 0;
    for (int i = n; i >= 1; --i)
    {
        const int p = rk[i];
        ans += (LL) a[p] * (R[p] - p) * (p - L[p]);
        L[R[p]] = L[p];
        R[L[p]] = R[p];
    }
    return ans;
}
// Builds the suffix structures for the text at str + 1 and returns
// solve() applied to its height array.
LL work(char str[])
{
    SA::Build(str);
    // Build() stores strlen(str + 1) in SA::n, so that value is passed on directly.
    return solve(SA::n, SA::ht);
}
// s1, s2: the two input strings, filled from index 1 by Input().
// NOTE(review): Solve() declares its own static s3, which hides this file-scope s3;
// nothing in the visible code refers to the global one.
char s1[maxN + 2], s2[maxN + 2], s3[maxN << 1];
// Reads two whitespace-delimited tokens from stdin into s1 and s2, starting at index 1.
// NOTE(review): "%s" is unbounded and the scanf results are not checked.
void Input()
{
scanf("%s", s1 + 1);
scanf("%s", s2 + 1);
}
// Prints work(s1 + '$' + s2) - work(s1) - work(s2) for the strings read by Input().
void Solve()
{
    // Local buffer for the concatenation; it hides the file-scope s3.
    static char s3[maxN + 2];
    const int len1 = (int) std::strlen(s1 + 1);
    const int len2 = (int) std::strlen(s2 + 1);

    // s3 = s1 '$' s2, stored from index 1. '$' is the separator between the two strings.
    for (int i = 1; i <= len1; ++i) s3[i] = s1[i];
    s3[len1 + 1] = '$';
    for (int i = 1; i <= len2; ++i) s3[len1 + 1 + i] = s2[i];
    s3[len1 + len2 + 2] = '\0';   // terminate explicitly instead of relying on zero-initialisation

    long long ans = work(s3);
    ans -= work(s1);
    ans -= work(s2);
    printf("%lld\n", ans);
}
// Entry point. Unless ONLINE_JUDGE is defined, stdin and stdout are first redirected
// to P3181.in and P3181.out; then Input() and Solve() run and the program exits with status 0.
int main()
{
#ifndef ONLINE_JUDGE
// NOTE(review): the freopen results are not checked, so a missing P3181.in is not reported.
freopen("P3181.in", "r", stdin);
freopen("P3181.out", "w", stdout);
#endif
Input();
Solve();
return 0;
}