AC自动机是tire树和KMP的结合,若模版串长度为l,有n个模版串,文本串长度为s,则AC自动机复杂度为O(l*n+s)。
KMP是一个模版串和一个文本串,最常见的是问模版串在文本串中的出现次数。所以失配的时候就去找模版串和文本串当前后缀匹配最长的前缀,若再失配以此类推。
但是若有多个模版串,若和当前模版串失配,可能会和其他模版串匹配,或者其他模版串的前缀能比当前模版串的前缀和文本串的后缀匹配的多,这时失配函数就不能像KMP中的失配函数一样了,而是应该指向能匹配最多的模版串此时应该比较的位置。
最常见的是给你文本串和一些模版串,让你找出模版串在文本串中出现的次数。
首先把模版串建立一个tire树,如果f[u]为匹配到节点u而不能往后接着匹配的失配指针,那么当每次失配的时候只要在去看失配指针的位置能不能匹配。若匹配到了模版串标记位置,从头开始到这个位置是一个模版串。不管当前节点是否匹配上,都要考虑的是同一个节点可能对应多个字符串的结尾。假设有模版串101,0,01,文本串101,一开始先匹配1,再匹配1后面的0,都匹配上了,因为这个0是101这一路的0,所以不是单词节点,但这并不代表目前没有匹配上,因为0已经匹配上了。所以需要一个last数组,last[j]代表沿着节点j失配指针往回走时,遇到的下一个单词节点编号。模版指针的位置一直是在以文本指针当前位置为结尾的能匹配最长后缀的那个模版上。所以就算当前节点不是单词节点,也需要判断last是否匹配。
失配函数怎么写呢?定义一个队列,首先把和tire树起点相连的节点加入队列,它们的f和last都是0。接下来按BFS顺序计算失配指针,重要思想是如果当前节点r有子节点u=ch[r][c],那么算f[u]就看ch[f[r]][c]存不存在(因为是BFS顺序,所以f[r]已经计算过了),若不存在继续寻找f[f[r]],若最后都没找到就是0了,last[u]则要判断f[u]是不是单词节点,如果是,last[u]=f[u],否则last[u]=last[f[u]]。
寻找出现次数的代码:
struct AhoCorasickAutomata{ int ch[MAXNODE][SIGMA_SIZE],val[MAXNODE],f[MAXNODE],last[MAXNODE],cnt[MAXN],sz; void init(){ memset(ch[0],0,sizeof(ch[0])); sz=1; } int idx(char c){ return c-'a'; } void insert(char *s,int v){ int u=0; for(int i=0;s[i];i++){ int c=idx(s[i]); if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); val[sz]++; ch[u][c]=sz++; } u=ch[u][c]; } val[u]=v; } void get_fail(){ queue<int> q; f[0]=0; for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[0][c]; if(u){ f[u]=last[u]=0; q.push(u); } } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[r][c]; if(!u) continue; q.push(u); int v=f[r]; while(v&&!ch[v][c]) v=f[v]; f[u]=ch[v][c]; //ch最开始已经初始化为0 last[u]=val[f[u]]?f[u]:last[f[u]]; } } } //其实在统计出现次数中这个函数的功能不是打印,只是这么叫比较通用 void print(int j){ if(j){ cnt[val[j]]++; print(last[j]); } } //在文本串中寻找模版串出现次数 void find(char *T){ int j=0; //当前节点编号 for(int i=0;T[i];i++){ int c=idx(T[i]); while(j&&!ch[j][c]) j=f[j]; j=ch[j][c]; if(val[j]) print(j); else if(last[j]) print(last[j]); } } };
struct AhoCorasickAutomata{ int ch[MAXNODE][SIGMA_SIZE],f[MAXNODE],match[MAXNODE],sz; void init(){ memset(ch[0],0,sizeof(ch[0])); memset(f,0,sizeof(f)); memset(match,0,sizeof(match)); sz=1; } void insert(char *s){ int u=0; for(int i=0;s[i];i++){ int c=idx[s[i]]; if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); ch[u][c]=sz++; } u=ch[u][c]; } match[u]=1; } void get_fail(){ queue<int> q; f[0]=0; for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[0][c]; if(u) q.push(u); } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[r][c]; if(!u){ ch[r][c]=ch[f[r]][c]; continue; } q.push(u); int v=f[r]; f[u]=ch[v][c]; match[u]|=match[f[u]]; } } } }ac; double get_prob(int u,int L){ if(!L) return 1; double &ans=d[u][L]; if(vis[u][L]) return ans; vis[u][L]=1; ans=0; for(int i=0;i<N;i++) if(!ac.match[ac.ch[u][i]]) ans+=prob[i]*get_prob(ac.ch[u][i],L-1); return ans; }