在应用中,AC自动机大多数是与DP结合起来用的,当然也有其他类型的应用。
最经典的应用之一:
给出一些串,这些串是“病毒串”,问有多少种长度为n且不包含病毒串(或者至少出现一次)的字符串。
这类问题中,病毒串长度一般很小,总长度一般不超过50,而n却很大,一般在10^9以上。
如果只有一个病毒串,那么我们只需要KMP就好了,比如
【BZOJ1009: [HNOI2008]GT考试】
我们先求出A[i][j],表示病毒串从i这个前缀添加一个字符,变为j这个前缀的方案数,这个可以先求出next数组,然后用类似一个匹配去求。
设dp[i][j]表示长度为i,末尾的j个字符为病毒串的前缀的方案数,那么有转移
dp[i][j] = ∑(dp[i][k] * A[k][j]),0 <= k < len
这是一个线性递推,我们可以用矩阵快速幂加速,转移矩阵即A数组。
答案为∑dp[len][i],0 <= i < len
直接上代码。
/* Pigonometry */ #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int maxn = 25; int n, m, p, s[maxn], fail[maxn]; struct _mat { int num[maxn][maxn]; } E, trans; inline int iread() { int f = 1, x = 0; char ch = getchar(); for(; ch < '0' || ch > '9'; ch = getchar()) f = ch == '-' ? -1 : 1; for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0'; return f * x; } inline _mat mul(_mat &A, _mat &B) { _mat C; for(int i = 0; i < m; i++) for(int j = 0; j < m; j++) { C.num[i][j] = 0; for(int k = 0; k < m; k++) C.num[i][j] = (C.num[i][j] + A.num[i][k] * B.num[k][j]) % p; } return C; } inline _mat qpow(_mat &A, int n) { _mat ans = E; for(_mat t = A; n; n >>= 1, t = mul(t, t)) if(n & 1) ans = mul(ans, t); return ans; } char str[maxn]; int main() { n = iread(); m = iread(); p = iread(); scanf("%s", str + 1); for(int i = 1; i <= m; i++) s[i] = str[i] - '0'; for(int i = 2, j = 0; i <= m; fail[i++] = j) { for(; j != 0 && s[j + 1] != s[i]; j = fail[j]); if(s[j + 1] == s[i]) j++; } for(int i = 0; i < m; i++) for(int j = 0; j <= 9; j++) { int k = i; for(; k != 0 && s[k + 1] != j; k = fail[k]); if(s[k + 1] == j) k++; trans.num[i][k] = (trans.num[i][k] + 1) % p; } for(int i = 0; i < m; i++) E.num[i][i] = 1; _mat res = qpow(trans, n); int ans = 0; for(int i = 0; i < m; i++) ans = (ans + res.num[0][i]) % p; printf("%d\n", ans); return 0; }
如果有多个病毒串,这时候就要用AC自动机了,思路类似。
比如【BZOJ1030: [JSOI2007]文本生成器】
这个题数据范围较小,可以不用矩阵快速幂。
题目求的是至少包含一个单词的方案数,我们转化为 总方案数 - 一个单词都不包含的方案数。
前者是26^m,一个快速幂就好了,后者用dp求。
设dp[i][j],表示字符串长度为i时,在AC自动机上的第j个节点,不包含病毒串的方案数,那么有转移
dp[i][son[j][k]] += dp[i - 1][j],0 <= k < 26,且son[j][k]不为病毒串的结尾(AC自动机插入和求fail数组时可以预处理出)。
答案为26^m - ∑dp[m][i]
/* Pigonometry */ #include <cstdio> #include <cstring> #define cls(a, x) memset(a, x, sizeof(a)) using namespace std; const int maxn = 6005, maxm = 105, p = 10007, maxq = 10000; int dp[maxm][maxn], q[maxq]; struct _acm { int son[maxn][26], fail[maxn], acmcnt; bool flag[maxn]; void init() { cls(son, 0); for(int i = 0; i < maxn; i++) fail[i] = flag[i] = 0; acmcnt = 0; } void insert(char *s) { int now = 0, len = strlen(s); for(int i = 0; i < len; i++) { int &pos = son[now][s[i] - 'A']; if(!pos) pos = ++acmcnt; now = pos; } flag[now] = 1; } void getfail() { int h = 0, t = 0; for(int i = 0; i < 26; i++) if(son[0][i]) q[t++] = son[0][i]; while(h != t) { int u = q[h++]; for(int i = 0; i < 26; i++) if(!son[u][i]) son[u][i] = son[fail[u]][i]; else { fail[q[t++] = son[u][i]] = son[fail[u]][i]; flag[son[u][i]] |= flag[fail[son[u][i]]]; } } } } acm; inline int qpow(int a, int n) { int ans = 1; for(int t = a; n; n >>= 1, t = t * t % p) if(n & 1) ans = ans * t % p; return ans; } char str[maxm]; int main() { int n, m; scanf("%d%d", &n, &m); for(int i = 1; i <= n; i++) { scanf("%s", str); acm.insert(str); } acm.getfail(); dp[0][0] = 1; for(int i = 1; i <= m; i++) for(int j = 0; j <= acm.acmcnt; j++) for(int k = 0; k < 26; k++) if(!acm.flag[acm.son[j][k]]) dp[i][acm.son[j][k]] = (dp[i][acm.son[j][k]] + dp[i - 1][j]) % p; int ans = qpow(26, m); for(int i = 0; i <= acm.acmcnt; i++) ans = (ans - dp[m][i] + p) % p; printf("%d\n", ans); return 0; }
上个题如果m非常大,那么就要用矩阵快速幂了。换种思路(其实还是线性递推)。
比如这个题【POJ2778: DNA Sequence】
设A[i][j]表示从AC自动机上的第i个节点添加一个字符到第j个节点的方案数。
A数组可以枚举i,然后枚举i的儿子求出。
然后对A跑矩阵快速幂就好了。
/* Pigonometry */ #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; typedef long long LL; const int maxn = 150, maxnode = maxn, maxd = 4, maxq = 10000, p = 100000; int n, m, id[26], q[maxq], size; struct _acm { int son[maxnode][maxd], acmcnt, fail[maxnode]; bool flag[maxnode]; void init() { memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(flag, 0, sizeof(flag)); } void insert(string s) { int now = 0, len = s.size(); for(int i = 0; i < len; i++) { int &pos = son[now][id[s[i] - 'A']]; if(!pos) pos = ++acmcnt; now = pos; } flag[now]++; } void getfail() { int h = 0, t = 0; for(int i = 0; i < 4; i++) if(son[0][i]) q[t++] = son[0][i]; while(h != t) { int u = q[h++]; for(int i = 0; i < 4; i++) if(!son[u][i]) son[u][i] = son[fail[u]][i]; else { fail[q[t++] = son[u][i]] = son[fail[u]][i]; flag[son[u][i]] |= flag[fail[son[u][i]]]; } } } } acm; struct _matrix { int num[maxn][maxn]; } trans, E; inline _matrix matmul(_matrix A, _matrix B) { _matrix ans; for(int i = 0; i < size; i++) for(int j = 0; j < size; j++) { ans.num[i][j] = 0; for(int k = 0; k < size; k++) ans.num[i][j] = (ans.num[i][j] + ((LL)A.num[i][k] * B.num[k][j]) % p) % p; } return ans; } _matrix matqpow(_matrix A, int n) { _matrix s = E; for(_matrix t = A; n; n >>= 1, t = matmul(t, t)) if(n & 1) s = matmul(s, t); return s; } int main() { ios::sync_with_stdio(false); cin >> m >> n; acm.init(); id['A' - 'A'] = 0; id['C' - 'A'] = 1; id['G' - 'A'] = 2; id['T' - 'A'] = 3; for(int i = 1; i <= m; i++) { string str; cin >> str; acm.insert(str); } acm.getfail(); size = acm.acmcnt + 1; for(int i = 0; i < size; i++) E.num[i][i] = 1; for(int i = 0; i < size; i++) if(!acm.flag[i]) for(int j = 0; j < 4; j++) if(!acm.flag[acm.son[i][j]]) trans.num[i][acm.son[i][j]]++; _matrix res = matqpow(trans, n); int ans = 0; for(int i = 0; i < size; i++) ans = (ans + res.num[0][i]) % p; printf("%d\n", ans); return 0; }
另外还有一个比较有趣的拓展,这个
【HDU2243: 考研路茫茫——单词情结】
这个题求长度不小于m的字符串的方案数。
即求A^1 + A^2 + A^3 + ... + A^m
可以构造一个分块矩阵,长这样:
A E
0 E
其中E为单位矩阵,对这个分块矩阵跑快速幂就好了。
/* Pigonometry */ #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; typedef unsigned long long ULL; const int maxn = 105, maxnode = 40, maxd = 26, maxq = maxn; ULL n, l; int size, q[maxq]; struct _acm { int son[maxnode][maxd], acmcnt, fail[maxnode]; bool flag[maxnode]; void init() { memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(flag, 0, sizeof(flag)); } void insert(string s) { int now = 0, len = s.size(); for(int i = 0; i < len; i++) { int &pos = son[now][s[i] - 'a']; if(!pos) pos = ++acmcnt; now = pos; } flag[now]++; } void getfail() { int h = 0, t = 0; for(int i = 0; i < 26; i++) if(son[0][i]) q[t++] = son[0][i]; while(h != t) { int u = q[h++]; for(int i = 0; i < 26; i++) if(!son[u][i]) son[u][i] = son[fail[u]][i]; else { fail[q[t++] = son[u][i]] = son[fail[u]][i]; flag[son[u][i]] |= flag[fail[son[u][i]]]; } } } } acm; struct _matrix { ULL num[maxn][maxn]; } trans, E, two; _matrix matmul(_matrix A, _matrix B) { _matrix ans; for(int i = 0; i < size; i++) for(int j = 0; j < size; j++) { ans.num[i][j] = 0; for(int k = 0; k < size; k++) ans.num[i][j] += A.num[i][k] * B.num[k][j]; } return ans; } _matrix matqpow(_matrix A, ULL n) { _matrix s = E; for(_matrix t = A; n; n >>= 1, t = matmul(t, t)) if(n & 1) s = matmul(s, t); return s; } int main() { ios::sync_with_stdio(false); for(int i = 0; i < maxn; i++) E.num[i][i] = 1; while(cin >> n >> l) { acm.init(); for(int i = 1; i <= n; i++) { string str; cin >> str; acm.insert(str); } acm.getfail(); memset(trans.num, 0, sizeof(trans.num)); for(int i = 0; i <= acm.acmcnt; i++) if(!acm.flag[i]) for(int j = 0; j < 26; j++) if(!acm.flag[acm.son[i][j]]) trans.num[i][acm.son[i][j]]++; for(int i = 1; i <= acm.acmcnt + 1; i++) trans.num[i - 1][acm.acmcnt + i] = trans.num[acm.acmcnt + i][acm.acmcnt + i] = 1; size = (acm.acmcnt << 1) + 2; _matrix res = matqpow(trans, l); ULL ans = 0; for(int i = 0; i < size; i++) if(!acm.flag[i]) ans += res.num[0][i]; memset(two.num, 0, sizeof(two.num)); two.num[0][0] = 26; two.num[0][1] = 1; two.num[1][0] = 0; two.num[1][1] = 1; size = 2; _matrix ret = matqpow(two, l); ULL tot = ret.num[0][0] + ret.num[0][1]; cout << tot - ans << endl; } return 0; }
还有一些其他类型的DP,比如
【Codeforces86C: Genetic engineering】
有m个模板串,要求字符串中每一个字符都至少被一个模板串覆盖,求长度为n的字符串个数。
设cover[i]表示AC自动机上第i个节点,以这个节点为结尾的模板串的长度的最大值。
设dp[i][j][k]表示长度为i,在AC自动机上第j个节点,结尾的k个字符未匹配的方案数。有转移
dp[i + 1][u][0] += dp[i][j][k],u为i的儿子,且cover[u] >= k + 1
dp[i + 1][u][k + 1] += dp[i][j][k],u为i的儿子,且cover[u] < k + 1
/* Pigonometry */ #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 1005, maxd = 5, maxnode = maxn, maxq = maxnode, p = 1000000009; int n, m, id[30], dp[maxn][105][15], q[maxq]; struct _acm { int son[maxnode][maxd], acmcnt, fail[maxnode], cover[maxnode]; void init() { memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(cover, 0, sizeof(cover)); } void insert(string s) { int now = 0, len = s.size(); for(int i = 0; i < len; i++) { int &pos = son[now][id[s[i] - 'A']]; if(!pos) pos = ++acmcnt; now = pos; } cover[now] = len; } void getfail() { int h = 0, t = 0; for(int i = 0; i < 4; i++) if(son[0][i]) q[t++] = son[0][i]; while(h != t) { int u = q[h++]; for(int i = 0; i < 4; i++) { int &pos = son[u][i]; if(!pos) pos = son[fail[u]][i]; else { fail[q[t++] = pos] = son[fail[u]][i]; cover[pos] = max(cover[pos], cover[fail[pos]]); } } } } } acm; int main() { ios::sync_with_stdio(false); cin >> n >> m; acm.init(); id['A' - 'A'] = 0; id['C' - 'A'] = 1; id['G' - 'A'] = 2; id['T' - 'A'] = 3; for(int i = 1; i <= m; i++) { string str; cin >> str; acm.insert(str); } acm.getfail(); dp[0][0][0] = 1; for(int i = 0; i < n; i++) for(int j = 0; j <= acm.acmcnt; j++) for(int k = 0; k <= 10; k++) if(dp[i][j][k]) for(int l = 0; l < 4; l++) { int u = acm.son[j][l]; if(acm.cover[u] >= k + 1) dp[i + 1][u][0] = (dp[i + 1][u][0] + dp[i][j][k]) % p; else dp[i + 1][u][k + 1] = (dp[i + 1][u][k + 1] + dp[i][j][k]) % p; } int ans = 0; for(int i = 0; i <= acm.acmcnt; i++) ans = (ans + dp[n][i][0]) % p; printf("%d\n", ans); return 0; }
【POJ3691: DNA repair】
给出n个病毒串,和一个字符串,问至少修改多少个字符,使得这个字符串不包含病毒串。
设dp[i][j]表示长度为i,在AC自动机上第j个节点,至少修改了多少字符。
枚举j的儿子u,如果u和原字符串的字符不相同,那么就要修改。
dp[i][u] = min(dp[i][u], dp[i - 1][j] + [k与原字符串的字符不相同])
/* Pigonometry */ #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 1005, maxd = 4, maxnode = maxn, maxq = maxnode; int n, id[26], dp[maxn][maxnode], q[maxq]; struct _acm { int son[maxnode][maxd], acmcnt, fail[maxnode]; bool flag[maxnode]; void init() { memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(flag, 0, sizeof(flag)); } void insert(string s) { int now = 0, len = s.size(); for(int i = 0; i < len; i++) { int &pos = son[now][id[s[i] - 'A']]; if(!pos) pos = ++acmcnt; now = pos; } flag[now]++; } void getfail() { int h = 0, t = 0; for(int i = 0; i < 4; i++) if(son[0][i]) q[t++] = son[0][i]; while(h != t) { int u = q[h++]; for(int i = 0; i < 4; i++) { int &v = son[u][i]; if(!v) v = son[fail[u]][i]; else { fail[q[t++] = v] = son[fail[u]][i]; flag[v] |= flag[fail[v]]; } } } } int getans(string s) { memset(dp, 0x3f, sizeof(dp)); dp[0][0] = 0; int len = s.size(); for(int i = 1; i <= len; i++) for(int j = 0; j <= acmcnt; j++) if(dp[i - 1][j] != 0x3f3f3f3f) for(int k = 0; k < 4; k++) { int u = son[j][k]; if(!flag[u]) dp[i][u] = min(dp[i][u], dp[i - 1][j] + (k != id[s[i - 1] - 'A'])); } int ans = 0x3f3f3f3f; for(int i = 0; i <= acmcnt; i++) if(!flag[i]) ans = min(ans, dp[len][i]); if(ans == 0x3f3f3f3f) ans = -1; return ans; } } acm; int main() { ios::sync_with_stdio(false); id['A' - 'A'] = 0; id['C' - 'A'] = 1; id['G' - 'A'] = 2; id['T' - 'A'] = 3; for(int cas = 1; ; cas++) { cin >> n; if(!n) break; acm.init(); for(int i = 1; i <= n; i++) { string str; cin >> str; acm.insert(str); } acm.getfail(); string str; cin >> str; printf("Case %d: %d\n", cas, acm.getans(str)); } return 0; }