HDU 2222 Keywords Search【ACAM】

HDU 2222 Keywords Search

代码风格模仿自 UESTC 每周算法讲解(AC 自动机专题,郭老师)。

第一行输入测试组数 T;每组数据先给出 n,接着输入 n 个模式串(关键词),最后给出一个文本串 L,求这些模式串中有多少个在 L 中出现过。

这是一道比较基础的 AC 自动机题目。这次我没有套用模板,而是从头手写了一遍,因此写下这篇博文。
将容易遗漏或者写错的地方进行了注释的标注,用于提示自己,也用于与大家进行分享。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;

const int maxn = 1e6 + 10; // input buffer capacity (text length up to 1e6, plus slack)
const int mtop = 5e5 + 10; // trie node capacity (keywords * length, plus the root)

// Aho-Corasick automaton over the lowercase alphabet 'a'..'z'.
// Per test case: init(), insert() every keyword, build(), then match(text).
// Capacity note: insert() does not check sz against mtop; the total keyword
// length must stay below mtop.
struct Aho {
	struct state {// trie node with a failure link
		int nxt[26];   // child per letter; 0 means "no child" (node 0 is the root and never a child)
		int fail, cnt; // fail: failure link (-1 for the root); cnt: keywords ending here, -1 once counted
	}stateTable[mtop];

	// Number of nodes in use. It starts at mtop so that the very first init()
	// clears the whole table; later calls only clear the nodes the previous
	// test case actually allocated.
	int sz = mtop;

	std::queue<int> q; // BFS queue used by build()

	// Map a character to its child slot, or -1 if it is outside 'a'..'z'.
	static int slot(char ch) {
		return (ch >= 'a' && ch <= 'z') ? ch - 'a' : -1;
	}

	// Reset the automaton to an empty trie containing only the root.
	void init() {
		while (!q.empty()) q.pop();
		const int used = sz < mtop ? sz : mtop;
		for (int i = 0; i < used; i++) {
			std::memset(stateTable[i].nxt, 0, sizeof stateTable[i].nxt);
			stateTable[i].fail = stateTable[i].cnt = 0;
		}
		sz = 1;
	}

	// Add one keyword. Duplicate keywords are counted separately.
	// A keyword containing a character outside 'a'..'z' is not supported by this
	// alphabet and is ignored, rather than indexing nxt[] out of bounds.
	void insert(const char *S) {
		int now = 0;
		for (const char *p = S; *p; ++p) {
			const int c = slot(*p);
			if (c < 0) return;
			if (!stateTable[now].nxt[c]) {
				stateTable[now].nxt[c] = sz++;
			}
			now = stateTable[now].nxt[c];
		}
		stateTable[now].cnt++;
	}

	// Compute failure links breadth-first, so every shallower node already has
	// its link when a node is processed.
	void build() {
		stateTable[0].fail = -1;
		q.push(0);
		while (!q.empty()) {
			const int u = q.front();
			q.pop();
			for (int i = 0; i < 26; i++) {
				const int child = stateTable[u].nxt[i];
				if (!child) continue;
				// Walk u's failure chain to the first state that has a child for
				// letter i. Stopping at the first hit keeps the longest proper
				// suffix. If no state has one, the link goes to the root.
				int target = 0;
				for (int v = stateTable[u].fail; v != -1; v = stateTable[v].fail) {
					if (stateTable[v].nxt[i]) {
						target = stateTable[v].nxt[i];
						break;
					}
				}
				stateTable[child].fail = target;
				q.push(child);
			}
		}
	}

	// Sum keyword counts along the failure chain starting at u. Each visited node
	// is marked as counted (cnt = -1), so every keyword is counted at most once.
	// When a node was marked, its entire failure chain was marked with it, so the
	// walk can stop at the first already-counted node.
	int Get(int u) {
		int res = 0;
		while (u > 0 && stateTable[u].cnt != -1) {
			res += stateTable[u].cnt;
			stateTable[u].cnt = -1;
			u = stateTable[u].fail;
		}
		return res;
	}

	// Return how many inserted keywords (with multiplicity) occur in S.
	int match(const char *S) {
		int now = 0;
		int res = 0;
		for (const char *p = S; *p; ++p) {
			const int c = slot(*p);
			if (c < 0) {
				// No keyword contains this character; any match must start after it.
				now = 0;
				continue;
			}
			// Follow failure links until some state can consume c (or we fall off the root).
			while (now != -1 && !stateTable[now].nxt[c]) {
				now = stateTable[now].fail;
			}
			now = (now == -1) ? 0 : stateTable[now].nxt[c];
			// Always inspect the failure chain. A suffix of the current state may end
			// a keyword even when the current state does not (e.g. "b" while in "ab").
			res += Get(now);
		}
		return res;
	}
}aho;


// Read T test cases. Each case is n keywords followed by one text; print how
// many of the keywords occur in the text, one count per line.
int main() {
	int T;
	if (scanf("%d", &T) != 1) return 0;
	// About 1 MB is too large for a local array on small default stacks
	// (e.g. 1 MB on Windows), so the buffer lives in static storage instead.
	static char S[maxn];
	while (T-- > 0) {
		int n;
		if (scanf("%d", &n) != 1) return 0;
		aho.init();
		for (int i = 1; i <= n; i++) {
			if (scanf("%s", S) != 1) return 0;
			aho.insert(S);
		}
		aho.build();
		if (scanf("%s", S) != 1) return 0;
		const int ans = aho.match(S);
		printf("%d\n", ans);
	}
	return 0;
}

用于记录WA点的代码(容易出现问题的地方都通过注释进行了标记)

#include
#include
#include
#include
#include
#include

using namespace std;

const int maxn = 1e6 + 10;   // input buffer capacity (text length up to 1e6, plus slack)
const int maxtop = 5e5 + 10; // trie node capacity (keywords * length, plus the root)

// Aho-Corasick automaton over the lowercase alphabet 'a'..'z'.
// Per test case: init(), insert() every keyword, build(), then match(text).
// Capacity note: insert() does not check sz against maxtop; the total keyword
// length must stay below maxtop.
struct Aho {

	struct state {
		int nxt[26];   // child per letter; 0 means "no child" (node 0 is the root and never a child)
		int cnt, fail; // cnt: keywords ending here, -1 once counted; fail: failure link (-1 for the root)
	}stateTable[maxtop];

	// Number of nodes in use. It starts at maxtop so that the very first init()
	// clears the whole table; later calls only clear the nodes the previous
	// test case actually allocated.
	int sz = maxtop;

	std::queue<int> q; // BFS queue used by build()

	// Map a character to its child slot, or -1 if it is outside 'a'..'z'.
	static int slot(char ch) {
		return (ch >= 'a' && ch <= 'z') ? ch - 'a' : -1;
	}

	// Reset the automaton to an empty trie containing only the root.
	void init() {
		while (!q.empty()) q.pop();
		const int used = sz < maxtop ? sz : maxtop;
		for (int i = 0; i < used; i++) {
			std::memset(stateTable[i].nxt, 0, sizeof stateTable[i].nxt);
			stateTable[i].cnt = stateTable[i].fail = 0;
		}
		sz = 1;
	}

	// Add one keyword. Duplicate keywords are counted separately.
	// A keyword containing a character outside 'a'..'z' is not supported by this
	// alphabet and is ignored, rather than indexing nxt[] out of bounds.
	void insert(const char *S) {
		int now = 0;
		for (const char *p = S; *p; ++p) {
			const int c = slot(*p);
			if (c < 0) return;
			if (!stateTable[now].nxt[c])
				stateTable[now].nxt[c] = sz++;
			// WA note: forgetting to step into the child here hangs every
			// character directly off the root.
			now = stateTable[now].nxt[c];
		}
		stateTable[now].cnt++;
	}

	// Compute failure links breadth-first, so every shallower node already has
	// its link when a node is processed.
	void build() {
		stateTable[0].fail = -1;
		q.push(0);
		while (!q.empty()) {
			const int u = q.front();
			q.pop();
			for (int i = 0; i < 26; i++) {
				const int child = stateTable[u].nxt[i];
				if (!child) continue;
				// WA note: test whether v has a child for letter i, and stop at the
				// first hit. Without the break, v keeps walking to -1 and the link
				// gets reset to the root. The first hit is the longest proper suffix.
				int target = 0;
				for (int v = stateTable[u].fail; v != -1; v = stateTable[v].fail) {
					if (stateTable[v].nxt[i]) {
						target = stateTable[v].nxt[i];
						break;
					}
				}
				stateTable[child].fail = target;
				q.push(child);
			}
		}
	}

	// Sum keyword counts along the failure chain starting at u. Each visited node
	// is marked as counted (cnt = -1), so every keyword is counted at most once.
	// When a node was marked, its entire failure chain was marked with it, so the
	// walk can stop at the first already-counted node.
	int Get(int u) {
		int res = 0;
		while (u > 0 && stateTable[u].cnt != -1) {
			res += stateTable[u].cnt;
			stateTable[u].cnt = -1;
			u = stateTable[u].fail;
		}
		return res;
	}

	// Return how many inserted keywords (with multiplicity) occur in S.
	int match(const char * S) {
		int now = 0;
		int res = 0;
		for (const char *p = S; *p; ++p) {
			const int c = slot(*p);
			if (c < 0) {
				// No keyword contains this character; any match must start after it.
				now = 0;
				continue;
			}
			// WA note: keep following failure links while the state has NO child
			// for c, then take that child (or restart at the root).
			while (now != -1 && !stateTable[now].nxt[c]) {
				now = stateTable[now].fail;
			}
			now = (now == -1) ? 0 : stateTable[now].nxt[c];
			// Always inspect the failure chain. A suffix of the current state may end
			// a keyword even when the current state does not (e.g. "b" while in "ab").
			res += Get(now);
		}
		return res;
	}


}aho;


// Read T test cases. Each case is n keywords followed by one text; print how
// many of the keywords occur in the text, one count per line.
int main() {
	int T;
	if (scanf("%d", &T) != 1) return 0;
	// About 1 MB is too large for a local array on small default stacks
	// (e.g. 1 MB on Windows), so the buffer lives in static storage instead.
	static char S[maxn];
	while (T-- > 0) {
		int n;
		if (scanf("%d", &n) != 1) return 0;
		aho.init();
		for (int i = 1; i <= n; i++) {
			if (scanf("%s", S) != 1) return 0;
			aho.insert(S);
		}
		aho.build();
		if (scanf("%s", S) != 1) return 0;
		const int ans = aho.match(S);
		printf("%d\n", ans);
	}

	return 0;
}

你可能感兴趣的:(Data,Structure)