http://acm.hdu.edu.cn/showproblem.php?pid=2222
题意:给出n个串,然后给一篇文章,问这n个串有多少个在文章里面出现过。。。
trick:n个串可能有相同的,需按照不同串处理。
分析:AC自动机模板题,自己按照算法思想写的,写得郁闷。。贴代码。。
不知道怎么要800+MS,别人的都基本上200+MS。。。
7月21号重新研究别人代码,终于找到差距了。。减到了187MS。。。附代码和改正说明:
7.21代码。。。
#include<iostream> #include<stdio.h> using namespace std; const int N=1000010; int n, a[N], up; char s[N]; struct node { int a[26]; int cnt, fail; void init() { memset(a, -1, sizeof(a)); cnt = fail = 0; } }trie[N]; inline void insert(char *s) { int p=0; while(*s) { if(trie[p].a[*s-'a']==-1) { trie[up].init(); trie[p].a[*s-'a'] = up++; } p = trie[p].a[*s-'a']; s++; } trie[p].cnt++; } int q[N], head, tail; void bfs() { int i, p, p1, p2; head = tail = 0; for(i=0; i<26; i++) { if(trie[0].a[i]!=-1) { p = trie[0].a[i]; trie[p].fail = 0; q[tail++] = p; } } while(head<tail) { p = q[head++]; for(i=0; i<26; i++) { if(trie[p].a[i]!=-1) { p2 = trie[p].a[i]; q[tail++] = p2; trie[p2].fail = 0; p1 = trie[p].fail; while(p1!=0 && trie[p1].a[i]==-1) p1 = trie[p1].fail; if(trie[p1].a[i]!=-1) { p1 = trie[p1].a[i]; trie[p2].fail = p1; } } } } } int query(char *s) { int p, p1; int cnt=0; p = 0; while(*s) { while(p!=0 && trie[p].a[*s-'a']==-1) p = trie[p].fail; p1 = trie[p].a[*s-'a']; if(p1!=-1) { p = p1; while(p1!=0 && trie[p1].cnt!=-1) //这里很关键。。。但注意:要改成一个以前不会出现的值-1。。。 { cnt += trie[p1].cnt; trie[p1].cnt = -1; p1 = trie[p1].fail; } /* while(p1!=0) //通过将本段while循环改成上面,减掉了不必要的再次查找一系列fail结点,从800+MS减到了187MS。。。 { if(trie[p1].flag==0) { cnt += trie[p1].cnt; trie[p1].flag = 1; } p1 = trie[p1].fail; } */ } s++; } return cnt; } int main() { int i, cas; scanf("%d", &cas); while(cas--) { scanf("%d", &n); gets(s); up = 1; trie[0].init(); for(i=0; i<n; i++) { gets(s); insert(s); } gets(s); bfs(); //for(i=0; i<up; i++) // printf("%d %d..\n", i, trie[i].fail); printf("%d\n", query(s)); } return 0; }
7.17号的代码。。。
//注意本题可能出现相同的keywords //为什么要800+MS呢。。。 #include<iostream> using namespace std; const int N=1000100; const int N1=250000; int n, q[N], head, tail; __int64 ans, cnt; char s[N]; struct node { int fail; int p[26]; int flag; bool visited; } tree[N1]; void insert(char *s) { int p=0; while(*s) { if(tree[p].p[*s-'a']==-1) tree[p].p[*s-'a'] = cnt++; p = tree[p].p[*s-'a']; s++; } tree[p].flag++; } void bfs() { int i, tmp; //把p做成指针 int pp; head = tail = 0; for(i=0; i<26; i++) { if(tree[0].p[i]!=-1) { pp = tree[0].p[i]; tree[pp].fail = 0; q[tail++] = pp; } } while(head<tail) { tmp = q[head++]; for(i=0; i<26; i++) { if(tree[tmp].p[i]!=-1) { q[tail++] = tree[tmp].p[i]; tree[tree[tmp].p[i]].fail = 0; pp = tree[tmp].fail; while(pp!=0 && tree[pp].p[i]==-1) pp = tree[pp].fail; if(tree[pp].p[i]!=-1) tree[tree[tmp].p[i]].fail = tree[pp].p[i]; } } } } int query(char *s) //flag==1 && count==0 cnt++; { int i=0, len=strlen(s); int tmp = 0, p=0; __int64 cnt=0; while(*s) { while(p!=0 && tree[p].p[*s-'a']==-1) p = tree[p].fail; //找到最末尾一个满足条件的。。。 tmp = tree[p].p[*s-'a']; if(tmp!=-1) { p = tmp; while(tmp!=0) { if(tree[tmp].flag!=0 && tree[tmp].visited==0)// && tree[tmp].count==0) { cnt += tree[tmp].flag; tree[tmp].visited = 1; } tmp = tree[tmp].fail; } } s++; } return cnt; } int main() { int i, cas; scanf("%d", &cas); while(cas--) { scanf("%d", &n); cnt = 1; for(i=0; i<N1&&i<50*n; i++) { tree[i].fail = 0; tree[i].flag = 0; tree[i].visited = 0; memset(tree[i].p, -1, sizeof(tree[i].p)); } gets(s); for(i=0; i<n; i++) { //scanf("%s", s); gets(s); insert(s); } //scanf("%s", s); gets(s); bfs(); /* for(i=1; i<=10; i++) // { printf("%d ..fail = %d\n", i, tree[i].fail); } */ ans = query(s); printf("%I64d\n", ans); } return 0; }