AC自动机模板及基础例题小结

AC自动机(Aho-Corasick Automation)用于解决多模式串匹配主串的问题
给所有模式串写一个Trie,在Trie上跑KMP,其中KMP的next数组变成了AC自动机的Fail指针
计算fail和计算next一样,用dp,只不过这里是树上dp
原理不再赘述,上模板

#include 
#include 
#include 
#include 
using namespace std;
class ACautomation {
private:
    int cnt, root;
    int **child, *sta, *fail;
    bool *flag;
public:
    ACautomation(int tot) {
        cnt = root = 1;
        sta = new int[tot+1];
        memset(sta, 0, sizeof(int)*(tot+1));
        fail = new int[tot+1];
        memset(fail, 0, sizeof(int)*(tot+1));
        flag = new bool[tot+1];
        memset(flag, 0, sizeof(bool)*(tot+1));
        child = new int*[tot+1];
        for (int i = 0; i <= tot; i++) {
            child[i] = new int[26];
            memset(child[i], 0, sizeof(int)*26);
        }
        for (int i = 0; i < 26; i++)    child[0][i] = 1;
    }
    ~ACautomation() {
        delete[] sta;
        delete[] fail;
        delete[] flag;
        for (int i = 0; i <= cnt; i++)
            delete[] child[i];
        delete[] child;
    }
    void insert(string& s) {
        int cur = 1, i;
        for (i = 0; i < s.length(); cur = child[cur][s[i++]-'a'])
            if (!child[cur][s[i]-'a'])  child[cur][s[i]-'a'] = ++cnt;
        sta[cur]++;
    }
    void setFail() {
        queue <int> que;
        que.push(1);
        while (!que.empty()) {
            int u = que.front();
            for (int i = 0; i < 26; i++) {
                if (child[u][i]) {
                    fail[child[u][i]] = child[fail[u]][i];
                    que.push(child[u][i]);
                } else {
                    child[u][i] = child[fail[u]][i];
                }
            }
            que.pop();
        }
    }
    int query(string& s) {
        int ret = 0, cur = 1, index;
        for (int i = 0; i < s.length(); i++) {
            flag[cur] = true;
            index = s[i]-'a';
            while (!child[cur][index])  cur = fail[cur];
            cur = child[cur][index];
            if (flag[cur])  continue;
            for (int j = cur; j; j = fail[j]) {
                ret += sta[j];
                sta[j] = 0;
            }
        }
        return ret;
    }
};
int main() {
    int n;
    cin >> n;
    ACautomation T(n*50);
    for (int i = 0; i < n; i++) {
        string s;
        cin >> s;
        T.insert(s);
    }
    T.setFail();
    string str;
    cin >> str;
    cout << T.query(str);
    return 0;
}

这是本人第一次写AC自动机封装的代码,很拙劣。由于我写Trie树一直把根节点设为1号,所以想出把0号节点所有儿子都指向1,这样如果fail指回到0的时候,child[u][i]都会将其重新指向根节点1,这样思考会简单一些。但是有时候要TLE

我做的第一道AC自动机题目是HDU2222,这是裸的模板,数据也很水

【HDU2222】 Keywords Search
Time Limit: 2000/1000 MS (Java/Others)
Memory Limit: 131072/131072 K (Java/Others)

Problem Description
In the modern time, Search engine came into the life of everybody like Google, Baidu, etc.
Wiskey also wants to bring this feature to his image retrieval system.
Every image have a long description, when users type some keywords to find the image, the system will match the keywords with description of image and show the image which the most keywords be matched.
To simplify the problem, giving you a description of image, and some keywords, you should tell me how many keywords will be match.

Input
First line will contain one integer means how many cases will follow by.
Each case will contain two integers N means the number of keywords and N keywords follow. (N <= 10000)
Each keyword will only contains characters ‘a’-‘z’, and the length will be not longer than 50.
The last line is the description, and the length will be not longer than 1000000.
Output
Print how many keywords are contained in the description.

Sample Input
1
5
she
he
say
shr
her
yasherhs
Sample Output
3

AC代码:

#include 
#include 
#include 
#include 
#define MAX_LETTER 500000
using namespace std;
int trie[MAX_LETTER+5][26], tag[MAX_LETTER+5], cnt;
int fail[MAX_LETTER+5], flag[MAX_LETTER+5], que[MAX_LETTER+5];
char s[55], str[1000005];
void insert() {
    scanf("%s", s);
    int cur, i, l = strlen(s);
    for (cur = 1, i = 0; i < l; cur = trie[cur][s[i++]-'a'])
        if (!trie[cur][s[i]-'a'])   trie[cur][s[i]-'a'] = ++cnt;
    tag[cur]++;
}
void BFS() {
    int head = 0, tail = 1;
    que[head] = 1;
    while (head < tail) {
        int u = que[head++];
        for (int i = 0; i < 26; i++) {
            if (trie[u][i]) {
                que[tail++] = trie[u][i];
                fail[trie[u][i]] = trie[fail[u]][i];
            } else {
                trie[u][i] = trie[fail[u]][i];
            }
        }
    }
}
void query() {
    int ret = 0, cur = 1, index;
    int len = strlen(str);
    for (int i = 0; i < len; i++) {
        flag[cur] = 1;
        index = str[i]-'a';
        while (!trie[cur][index])   cur = fail[cur];
        cur = trie[cur][index];
        if (flag[cur])  continue;
        for (int j = cur; j; j = fail[j]) {
            ret += tag[j];
            tag[j] = 0;
        }
    }
    printf("%d\n", ret);
}
int main() {
    int T, n;
    for (int i = 0; i < 26; i++)    trie[0][i] = 1;
    scanf("%d", &T);
    while (T--) {
        cnt = 1;
        scanf("%d", &n);
        for (int i = 0; i < n; i++) insert();
        BFS();
        scanf("%s", str);
        query();
        for (int i = 1; i <= cnt; i++) {
            tag[i] = fail[i] = flag[i] = 0;
            for (int j = 0; j < 26; j++)
                trie[i][j] = 0;
        }
    }
    return 0;
}

HDU3065和这道题很像,也是几乎裸的模板

【HDU3065】病毒侵袭持续中
Time Limit: 2000/1000 MS (Java/Others)
Memory Limit: 32768/32768 K (Java/Others)

Problem Description
小t非常感谢大家帮忙解决了他的上一个问题。然而病毒侵袭持续中。在小t的不懈努力下,他发现了网路中的“万恶之源”。这是一个庞大的病毒网站,他有着好多好多的病毒,但是这个网站包含的病毒很奇怪,这些病毒的特征码很短,而且只包含“英文大写字符”。当然小t好想好想为民除害,但是小t从来不打没有准备的战争。知己知彼,百战不殆,小t首先要做的是知道这个病毒网站特征:包含多少不同的病毒,每种病毒出现了多少次。大家能再帮帮他吗?

Input
第一行,一个整数N(1<=N<=1000),表示病毒特征码的个数。
接下来N行,每行表示一个病毒特征码,特征码字符串长度在1—50之间,并且只包含“英文大写字符”。任意两个病毒特征码,不会完全相同。
在这之后一行,表示“万恶之源”网站源码,源码字符串长度在2000000之内。字符串中字符都是ASCII码可见字符(不包括回车)。
Output
按以下格式每行一个,输出每个病毒出现次数。未出现的病毒不需要输出。
病毒特征码: 出现次数
冒号后有一个空格,按病毒特征码的输入顺序进行输出。

Sample Input
3
AA
BB
CC
ooxxCC%dAAAoen….END
Sample Output
AA: 2
CC: 1

Hint
题目描述中没有被提及的所有情况都应该进行考虑。比如两个病毒特征码可能有相互包含或者有重叠的特征码段。
计数策略也可一定程度上从Sample中推测。

AC代码:

#include 
#include 
#include 
#include 
#define MAX_N 1000
#define MAX_LETTER 100000
#define DICNUM 26
using namespace std;
int n, cnt;
int child[MAX_LETTER+5][DICNUM], id[MAX_LETTER+5], fail[MAX_LETTER+5], ans[MAX_N+5];
int que[MAX_LETTER+5];
char s[2000005], dic[1005][55];
void init() {
    cnt = 1;
    memset(child, 0, sizeof(child));
    for (int i = 0; i < DICNUM; i++)    child[0][i] = 1;
    memset(id, 0, sizeof(id));
    memset(fail, 0, sizeof(fail));
    memset(ans, 0, sizeof(ans));
}
void insert() {
    for (int j = 1; j <= n; j++) {
        int cur = 1, l = strlen(dic[j]);
        for (int i = 0; i < l; cur = child[cur][dic[j][i++]-'A'])
            if (!child[cur][dic[j][i]-'A']) child[cur][dic[j][i]-'A'] = ++cnt;
        id[cur] = j;
    }
}
void setFail() {
    int head = 0, tail = 1;
    que[head] = 1;
    while (head < tail) {
        int u = que[head++];
        for (int i = 0; i < DICNUM; i++) {
            if (child[u][i]) {
                que[tail++] = child[u][i];
                int cur = fail[u];
                while (!child[cur][i])  cur = fail[cur];
                fail[child[u][i]] = child[cur][i];
            } else {
                child[u][i] = child[fail[u]][i];
            }
        }
    }
}
void query() {
    int cur = 1, index, l = strlen(s);
    for (int i = 0; i < l; i++) {
        index = s[i]-'A';
        if (index < 0 || index > 25) {
            cur = 1;
        } else {
            while (!child[cur][index])  cur = fail[cur];
            cur = child[cur][index];
        }
        for (int j = cur; j; j = fail[j])
            if (id[j])
                ans[id[j]]++;
    }
}
int main() {
    while (scanf("%d", &n) != EOF) {
        init();
        for (int i = 1; i <= n; i++)    scanf("%s", dic[i]);
        insert();
        setFail();
        scanf("%s", s);
        query();
        for (int i = 1; i <= n; i++)
            if (ans[i])
                printf("%s: %d\n", dic[i], ans[i]);
    }
    return 0;
}

这两道题用我拙劣的写法还可以过,但是HDU2896就过不了

【HDU2896】病毒侵袭
Time Limit: 2000/1000 MS (Java/Others)
Memory Limit: 32768/32768 K (Java/Others)

Problem Description
当太阳的光辉逐渐被月亮遮蔽,世界失去了光明,大地迎来最黑暗的时刻。。。。在这样的时刻,人们却异常兴奋——我们能在有生之年看到500年一遇的世界奇观,那是多么幸福的事儿啊~~
但网路上总有那么些网站,开始借着民众的好奇心,打着介绍日食的旗号,大肆传播病毒。小t不幸成为受害者之一。小t如此生气,他决定要把世界上所有带病毒的网站都找出来。当然,谁都知道这是不可能的。小t却执意要完成这不能的任务,他说:“子子孙孙无穷匮也!”(愚公后继有人了)。
万事开头难,小t收集了好多病毒的特征码,又收集了一批诡异网站的源码,他想知道这些网站中哪些是有病毒的,又是带了怎样的病毒呢?顺便还想知道他到底收集了多少带病毒的网站。这时候他却不知道何从下手了。所以想请大家帮帮忙。小t又是个急性子哦,所以解决问题越快越好哦~~

Input
第一行,一个整数N(1<=N<=500),表示病毒特征码的个数。
接下来N行,每行表示一个病毒特征码,特征码字符串长度在20—200之间。
每个病毒都有一个编号,依此为1—N。
不同编号的病毒特征码不会相同。
在这之后一行,有一个整数M(1<=M<=1000),表示网站数。
接下来M行,每行表示一个网站源码,源码字符串长度在7000—10000之间。
每个网站都有一个编号,依此为1—M。
以上字符串中字符都是ASCII码可见字符(不包括回车)。
Output
依次按如下格式输出按网站编号从小到大输出,带病毒的网站编号和包含病毒编号,每行一个含毒网站信息。
web 网站编号: 病毒编号 病毒编号 …
冒号后有一个空格,病毒编号按从小到大排列,两个病毒编号之间用一个空格隔开,如果一个网站包含病毒,病毒数不会超过3个。
最后一行输出统计信息,如下格式
total: 带病毒网站数
冒号后有一个空格。

Sample Input
3
aaa
bbb
ccc
2
aaabbbccc
bbaacc
Sample Output
web 1: 1 2 3
total: 1

AC代码:

#include 
#include 
#include 
#include 
using namespace std;
struct AC {
    int child[200*500+500][128], fail[200*500+500], end[200*500+500];
    int root, cnt;
    int newnode() {
        for (int i = 0; i < 128; i++)   child[cnt][i] = -1;
        end[cnt++] = -1;
        return cnt-1;
    }
    void init() {
        cnt = 0;
        root = newnode();
    }
    void insert(int id, char s[]) {
        int len = strlen(s), cur = root;
        for (int i = 0; i < len; cur = child[cur][s[i++]])
            if (child[cur][s[i]] == -1) child[cur][s[i]] = newnode();
        end[cur] = id;
    }
    void setFail() {
        queue <int> que;
        fail[root] = root;
        for (int i = 0; i < 128; i++) {
            if (child[root][i] == -1) {
                child[root][i] = root;
            } else {
                fail[child[root][i]] = root;
                que.push(child[root][i]);
            }
        }
        while (!que.empty()) {
            int u = que.front();
            for (int i = 0; i < 128; i++) {
                if (child[u][i] != -1) {
                    fail[child[u][i]] = child[fail[u]][i];
                    que.push(child[u][i]);
                } else {
                    child[u][i] = child[fail[u]][i];
                }
            }
            que.pop();
        }
    }
    bool flag[505];
    bool query(int n, int id, char s[]) {
        int len = strlen(s);
        int cur = root;
        bool mark = false;
        memset(flag, 0, sizeof(flag));
        for (int i = 0; i < len; i++) {
            cur = child[cur][s[i]];
            int tmp = cur;
            while (tmp != root) {
                if (end[tmp] != -1) flag[end[tmp]] = mark = true;
                tmp = fail[tmp];
            }
        }
        if (!mark)  return false;
        printf("web %d:", id);
        for (int i = 1; i <= n; i++)
            if (flag[i])
                printf(" %d", i);
        printf("\n");
        return true;
    }
} ac;
char str[10005];
int main() {
    int n, m, ans;
    while (scanf("%d", &n) != EOF) {
        ac.init();
        for (int i = 1; i <= n; i++) {
            scanf("%s", str);
            ac.insert(i, str);
        }
        ac.setFail();
        ans = 0;
        scanf("%d", &m);
        for (int i = 1; i <= m; i++) {
            scanf("%s", str);
            ans += ac.query(n, i, str);
        }
        printf("total: %d\n", ans);
    }
    return 0;
}

接下来看一道稍微有些变化的题(虽然还是基础水题)
【ZOJ3034】 Detect the Virus
One day, Nobita found that his computer is extremely slow. After several hours’ work, he finally found that it was a virus that made his poor computer slow and the virus was activated by a misoperation of opening an attachment of an email.
Nobita did use an outstanding anti-virus software, however, for some strange reason, this software did not check email attachments. Now Nobita decide to detect viruses in emails by himself.
To detect an virus, a virus sample (several binary bytes) is needed. If these binary bytes can be found in the email attachment (binary data), then the attachment contains the virus.
Note that attachments (binary data) in emails are usually encoded in base64. To encode a binary stream in base64, first write the binary stream into bits. Then take 6 bits from the stream in turn, encode these 6 bits into a base64 character according the following table:
That is, translate every 3 bytes into 4 base64 characters. If the original binary stream contains 3k + 1 bytes, where k is an integer, fill last bits using zero when encoding and append ‘==’ as padding. If the original binary stream contains 3k + 2 bytes, fill last bits using zero when encoding and append ‘=’ as padding. No padding is needed when the original binary stream contains 3k bytes.

Value Encoding Value Encoding
0 A 32 g
1 B 33 h
2 C 34 i
3 D 35 j
4 E 36 k
5 F 37 l
6 G 38 m
7 H 39 n
8 I 40 o
9 J 41 p
10 K 42 q
11 L 43 r
12 M 44 s
13 N 45 t
14 O 46 u
15 P 47 v
16 Q 48 w
17 R 49 x
18 S 50 y
19 T 51 z
20 U 52 0
21 V 53 1
22 W 54 2
23 X 55 3
24 Y 56 4
25 Z 57 5
26 a 58 6
27 b 59 7
28 c 60 8
29 d 61 9
30 e 62 +
31 f 63 /

For example, to encode ‘hello’ into base64, first write ‘hello’ as binary bits, that is: 01101000 01100101 01101100 01101100 01101111
Then, take 6 bits in turn and fill last bits as zero as padding (zero padding bits are marked in bold): 011010 000110 010101 101100 011011 000110 111100
They are 26 6 21 44 27 6 60 in decimal. Look up the table above and use corresponding characters: aGVsbG8
Since original binary data contains 1 * 3 + 2 bytes, padding is needed, append ‘=’ and ‘hello’ is finally encoded in base64: aGVsbG8=
Section 5.2 of RFC 1521 describes how to encode a binary stream in base64 much more detailedly:
Here is a piece of ANSI C code that can encode binary data in base64. It contains a function, encode (infile, outfile), to encode binary file infile in base64 and output result to outfile.

Input
Input contains multiple cases (about 15, of which most are small ones). The first line of each case contains an integer N (0 <= N <= 512). In the next N distinct lines, each line contains a sample of a kind of virus, which is not empty, has not more than 64 bytes in binary and is encoded in base64. Then, the next line contains an integer M (1 <= M <= 128). In the following M lines, each line contains the content of a file to be detected, which is not empty, has no more than 2048 bytes in binary and is encoded in base64.
There is a blank line after each case.
Output
For each case, output M lines. The ith line contains the number of kinds of virus detected in the ith file.
Output a blank line after each case.

Sample Input
3
YmFzZTY0
dmlydXM=
dDog
1
dGVzdDogdmlydXMu

1
QA==
2
QA==
ICAgICAgICA=
Sample Output
2

1
0

Hint
In the first sample case, there are three virus samples: base64, virus and t: , the data to be checked is test: virus., which contains the second and the third, two virus samples.

思路很简单,先解码,然后就是模板题了

#include 
#include 
#include 
#include 
#define DICNUM 256
using namespace std;
struct ACautomation {
    int cnt, root;
    int child[50000+50][DICNUM], end[50000+50], fail[50000+50];
    int Map[256];
    int newnode() {
        for (int i = 0; i < DICNUM; i++)    child[cnt][i] = -1;
        end[cnt++] = -1;
        return cnt-1;
    }
    void init() {
        cnt = 0;
        root = newnode();
        for (int i = 0; i < 26; i++)    Map[i+'A'] = i;
        for (int i = 26; i < 52; i++)   Map[i-26+'a'] = i;
        for (int i = 52; i < 62; i++)   Map[i-52+'0'] = i;
        Map['+'] = 62, Map['/'] = 63;
    }
    int tmp[5000+5], tmpl;
    void ReEncode(char s[]) {
        memset(tmp, 0, sizeof(tmp));
        tmpl = 0;
        for (int i = 0, len = 0, x = 0; s[i] && s[i] != '='; i++) {
            len += 6, x = (x<<6)|Map[s[i]];
            if (len >= 8) {
                tmp[tmpl++] = (x>>(len-8))&0xff;
                len -= 8;
            }
        }
    }
    void insert(int id, char s[]) {
        ReEncode(s);
        int cur = root;
        for (int i = 0; i < tmpl; cur = child[cur][tmp[i++]])
            if (child[cur][tmp[i]] == -1)   child[cur][tmp[i]] = newnode();
        end[cur] = id;
    }
    void setFail() {
        queue <int> que;
        fail[root] = root;
        for (int i = 0; i < DICNUM; i++) {
            if (child[root][i] == -1) {
                child[root][i] = root;
            } else {
                fail[child[root][i]] = root;
                que.push(child[root][i]);
            }
        }
        while (!que.empty()) {
            int cur = que.front();
            for (int i = 0; i < DICNUM; i++) {
                if (child[cur][i] == -1) {
                    child[cur][i] = child[fail[cur]][i];
                } else {
                    fail[child[cur][i]] = child[fail[cur]][i];
                    que.push(child[cur][i]);
                }
            }
            que.pop();
        }
    }
    bool flag[50000+50];
    int query(int id, char s[]) {
        ReEncode(s);
        int cur = root, ret = 0;
        memset(flag, 0, sizeof(flag));
        for (int i = 0; i < tmpl; i++) {
            cur = child[cur][tmp[i]];
            int t = cur;
            while (t != root) {
                if (end[t] != -1 && !flag[end[t]])  ret++, flag[end[t]] = true;
                t = fail[t];
            }
        }
        return ret;
    } 
} ACauto;
int main() {
    int n, m;
    char str[5000+5];
    while (scanf("%d", &n) != EOF) {
        ACauto.init();
        for (int i = 1; i <= n; i++) {
            scanf("%s", str);
            ACauto.insert(i, str);
        }
        ACauto.setFail();
        scanf("%d", &m);
        for (int i = 1; i <= m; i++) {
            scanf("%s", str);
            printf("%d\n", ACauto.query(i, str));
        }
        printf("\n");
    }
    return 0;
}

AC自动机的玄学博大精深,还可以和dp扯上关系,还是多练题为好

你可能感兴趣的:(AC自动机)