【AC自动机】hdu2222 hdu2896 hdu3065 zoj3430 poj2778 hdu2243

AC自动机用于多个模式串与多个母串的匹配。
第一步:根据模式串建立字典树

int len=strlen(w), r=root;
    for(int i=0;i<len;++i)
    {
    if(tree[r].ch[w[i]])r=tree[r].ch[w[i]];
        else r=tree[r].ch[w[i]]=++cnt;
    }
    ++tree[r].cnt;//cnt为在该节点结束的模式串的数量

第二步:计算每一个节点的fail指针(与kmp中next数组相似)。首先找到u的父亲的fail指针v。若v对应的儿子不为空,则u的fail指针指向v。否则访问v的fail指针。

void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<128;++i)//根据模式串的字符大小确定i的范围
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)tree[tmp].fail=tree[tree[p].fail].ch[i];
                q[++tail]=tmp;
            }
            else tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

第三步:在字典树中查找

void query()
{
    int len=strlen(w), ans=0, p=0, tmp;
    for(int i=0;i<len;++i)
    {
        p=tree[p].ch[w[i]-'a'];
        tmp=p;
        while(tree[tmp].pos)
        {
            ans+=tree[tmp].pos;
            tree[tmp].pos=0;
            tmp=tree[tmp].fail;
        }
    }
    printf("%d\n",ans);
}

模板题:
hdu2222
题目大意:问在一个母串中有多少个模式串出现
直接上模板

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 500005
#define MAXM 1000005
using namespace std;

int n;
char w[MAXM];

struct node
{
    int pos, ch[26], fail;
    inline void init()
    {
        fail=pos=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN];
int cnt, root;

void add()
{
    int len=strlen(w), r=root;
    for(int i=0;i<len;++i)
    {
        w[i]-='a';
        if(tree[r].ch[w[i]])r=tree[r].ch[w[i]];
        else r=tree[r].ch[w[i]]=++cnt;
    }
    ++tree[r].pos;
}

int head, tail, q[MAXN];
void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<26;++i)
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)tree[tmp].fail=tree[tree[p].fail].ch[i];
                q[++tail]=tmp;
            }
            else tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

void query()
{
    int len=strlen(w), ans=0, p=0, tmp;
    for(int i=0;i<len;++i)
    {
        p=tree[p].ch[w[i]-'a'];
        tmp=p;
        while(tree[tmp].pos)
        {
            ans+=tree[tmp].pos;
            tree[tmp].pos=0;
            tmp=tree[tmp].fail;
        }
    }
    printf("%d\n",ans);
}

int main()
{
    int cas;
    scanf("%d",&cas);
    while(cas--)
    {
        cnt=root=0;
        scanf("%d",&n);
        for(int i=0;i<n;++i)
        {
            scanf("%s",w);
            add();
        }
        bfs();
        scanf("%s",w);
        query();
        for(int i=0;i<=cnt;++i)
            tree[i].init();
    }
    return 0;
}

hdu2896
一定要注意模式串字符的范围。蒟蒻在此处RE了几次才发现…
提供指针版
不过还是数组版的好调试一些 233

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 100005
#define MAXM 10005
using namespace std;

int n, m;
bool vis[MAXN];
struct node
{
    int pos;
    node *ch[128], *fail;
    inline void init()
    {
        fail=0, pos=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN], *cnt, *root;

char w[MAXM];
void add(int j)
{
    int len=strlen(w), v;
    node *r=root;
    for(int i=0;i<len;++i)
    {
        v=w[i];
        if(r->ch[v])r=r->ch[v];
        else r=r->ch[v]=++cnt;
    }
    r->pos=j;
    vis[j]=1;
}

node *q[MAXN];
int head, tail;
void bfs()
{
    head=tail=0;
    q[++tail]=root;
    node *p, *son, *tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<128;++i)
        {
            son=p->ch[i];
            if(son)
            {
                if(p==root)son->fail=p;
                else
                {
                    tmp=p->fail;
                    while(tmp)
                    {
                        if(tmp->ch[i])
                        {
                            son->fail=tmp->ch[i];
                            break;
                        }
                        tmp=tmp->fail;
                    }
                    if(!tmp)son->fail=root;
                }
                q[++tail]=son;
            }
        }
    }
}

int num[MAXM], tmp, tot;
void query(int j)
{
    int len=strlen(w), v;
    node *p=root, *temp;
    for(int i=0;i<len;++i)
    {
        v=w[i];
        while(!p->ch[v]&&p!=root)
            p=p->fail;
        p=p->ch[v];
        if(!p)p=root;
        temp=p;
        while(vis[temp->pos])
        {
            num[++tmp]=temp->pos;
            vis[temp->pos]=0;
            temp=temp->fail;
        }
    }
    if(tmp)
    {
        sort(num+1,num+tmp+1);
        printf("web %d:",j);
        for(int i=1;i<=tmp;++i)
        {
            printf(" %d",num[i]);
            vis[num[i]]=1;
        }
        ++tot, tmp=0;
        puts("");
    }
}

int main()
{
    while(~scanf("%d",&n))
    {
        root=cnt=tree, tot=0;
        for(int i=1;i<=n;++i)
        {
            scanf("%s",w);
            add(i);
        }
        bfs();
        scanf("%d",&m);
        for(int i=1;i<=m;++i)
        {
            scanf("%s",w);
            query(i);
        }
        printf("total: %d\n",tot);
        for(node *p=tree;p<=cnt;++p)
            p->init();
    }
    return 0;
}

hdu3065
这里就涉及到了计数的问题。
不进行标记。在之前的模板中都加了一个优化,即访问了一个节点就打上标记。

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXM 2000005
#define MAXN 50005
using namespace std;

int n, ans[1005];
char w[1005][55], s[MAXM];

struct node
{
    node *ch[26], *fail;
    int pos;
    void init()
    {
        fail=0, pos=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN], *root, *cnt;

void add(node *r,int j)
{
    int len=strlen(w[j]), v;
    for(int i=0;i<len;++i)
    {
        v=w[j][i]-'A';
        if(r->ch[v])r=r->ch[v];
        else r=r->ch[v]=++cnt;
    }
    r->pos=j;
}

int head, tail;
node *q[MAXN];
void bfs()
{
    head=tail=0;
    q[++tail]=root;
    node *p, *son, *tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<26;++i)
        {
            son=p->ch[i];
            if(son)
            {
                if(p==root)son->fail=p;
                else
                {
                    tmp=p->fail;
                    while(tmp)
                    {
                        if(tmp->ch[i])
                        {
                            son->fail=tmp->ch[i];
                            break;
                        }
                        tmp=tmp->fail;
                    }
                    if(!tmp)son->fail=root;
                }
                q[++tail]=son;
            }
        }
    }
}

void query()
{
    int len=strlen(s), v;
    node *p=root, *temp;
    for(int i=0;i<len;++i)
    {
        if(s[i]<'A'||s[i]>'Z')
        {
            p=root;
            continue;
        }
        v=s[i]-'A';
        while(p)
        {
            if(p->ch[v])
            {
                p=p->ch[v];
                break;
            }
            p=p->fail;
        }
        if(!p)p=root;
        else
        {
            temp=p;
            while(temp)
            {
                if(temp->pos)++ans[temp->pos];
                temp=temp->fail;
            }
        }
    }
}

int main()
{
    while(~scanf("%d",&n))
    {
        root=cnt=tree;
        for(int i=1;i<=n;++i)
        {
            scanf("%s",w[i]);
            add(root,i);
        }
        bfs();
        scanf("%s",s);
        query();
        for(int i=1;i<=n;++i)
            if(ans[i])
            {
                printf("%s: %d\n",w[i],ans[i]);
                ans[i]=0;
            }
        for(node *p=tree;p<=cnt;++p)
            p->init();
    }
    return 0;

zoj3430
题目大意:给你一个加密规则:将所有字符写成二进制并串联起来。然后每6位数组成一个新的二进制数,再转化为十进制数,根据密码表翻译成字符。若len%3==1,就再加上=。若len%3==2,就加上==。
现在已知一些被加密后的模式串和一些加密后的母串。求每一个母串中出现了多少个模式串。

这道题巧妙地运用位运算可以很方便的还原字符串。要注意虽然密码串是从‘0’到‘z’,但原串是0到255。而且是多组数据,一定要清零。被坑惨了…

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 50050
using namespace std;

char w[MAXN];
bool vis[MAXN];
int n, temp[MAXN], cnt;

struct node
{
    int ch[256], fail, cnt;
    void init()
    {
        cnt=fail=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN];

void add(int r)
{
    for(int i=1;i<=temp[0];++i)
    {
        if(tree[r].ch[temp[i]])r=tree[r].ch[temp[i]];
        else r=tree[r].ch[temp[i]]=++cnt;
    }
    tree[r].cnt=1;
}

int key[256];
void table()
{
    for(int i=0;i<26;++i)key[i+'A']=i;
    for(int i=0;i<26;++i)key[i+'a']=i+26;
    for(int i=0;i<10;++i)key[i+'0']=i+52;
    key['+']=62, key['/']=63;
}

void change(char s[])
{
    temp[0]=0;
    int len, x=0;
    for(len=strlen(s);s[len-1]=='=';--len);
    for(int i=0, tmp=0;i<len;++i)
    {
        x=(x<<6)|key[s[i]], tmp+=6;
        if(tmp>=8)
        {
            temp[++temp[0]]=(x>>(tmp-8))&255;
            tmp-=8;
        }
    }
}

int head, tail, q[MAXN];
void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<256;++i)
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)tree[tmp].fail=tree[tree[p].fail].ch[i];
                q[++tail]=tmp;
            }
            else tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

void query(int w[])
{
    int ans=0, p=0, tmp;
    for(int i=1;i<=w[0];++i)
    {
        p=tree[p].ch[w[i]];
        tmp=p;
        while(vis[tmp])
        {
            ans+=tree[tmp].cnt;
            vis[tmp]=0;
            tmp=tree[tmp].fail;
        }
    }
    printf("%d\n",ans);
}

int main()
{
    table();
    while(~scanf("%d",&n))
    {
        cnt=0;
        for(int i=1;i<=n;++i)
        {
            scanf("%s",w);
            change(w);
            add(0);
        }
        bfs();
        scanf("%d",&n);
        while(n--)
        {
            memset(vis,1,sizeof vis);
            scanf("%s",w);
            change(w);
            query(temp);
        }
        for(int i=0;i<=cnt;++i)
            tree[i].init();
        puts("");
    }
    return 0;
}

poj2778
题目大意:有 n 个病毒的DNA序,求长度为 l 的DNA序中不含病毒的个数。

不得不承认蒟蒻是看了题解的…
首先对这 n 个病毒DNA序构建AC自动机。自动集中的节点都是由有向边连接而成的,那么将那些打了标记的节点删除,剩下的就是一个图。问题就转化为从起点0开始,走N步有多少种方案。
这样就是一个矩阵加速的问题了。

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 110
#define LL long long int
#define mod 100000
using namespace std;

struct mat
{
    LL num[MAXN][MAXN], n;
    void init()
    {
        memset(num,0,sizeof num);
        n=0;
    }
    mat operator * (const mat &a)const
    {
        mat ans;
        ans.init();
        ans.n=n;
        for(int i=0;i<=n;++i)
            for(int j=0;j<=n;++j)
                for(int k=0;k<=n;++k)
                    ans.num[i][k]=(ans.num[i][k]+num[i][j]*a.num[j][k])%mod;
        return ans;
    }
}ans;

mat power(mat a,int pos)
{
    mat ans=a;
    while(pos)
    {
        if(pos&1)ans=ans*a;
        a=a*a;
        pos>>=1;
    }
    return ans;
}

inline int getid(char a)
{
    if(a=='A')return 0;
    if(a=='C')return 1;
    if(a=='G')return 2;
    return 3;
}

struct node
{
    int ch[5], fail, cnt;
    void init()
    {
        fail=cnt=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN];

int cnt, root;
char w[MAXN];
void add()
{
    int len=strlen(w), r=root, v;
    for(int i=0;i<len;++i)
    {
        v=getid(w[i]);
        if(tree[r].ch[v])r=tree[r].ch[v];
        else r=tree[r].ch[v]=++cnt;
    }
    ++tree[r].cnt;
}

int q[MAXN], head, tail;
void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<4;++i)
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)
                {
                    tree[tmp].fail=tree[tree[p].fail].ch[i];
                    tree[tmp].cnt+=tree[tree[tree[p].fail].ch[i]].cnt;
                }
                q[++tail]=tmp;
            }
            else
                tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

void build()
{
    ans.init();
    for(int i=0;i<=cnt;++i)
    {
        for(int j=0;j<4;++j)
        {
            if(tree[tree[i].ch[j]].cnt)continue;
            ++ans.num[i][tree[i].ch[j]];
        }
    }
    ans.n=cnt;
}

int n, m;
LL out;
int main()
{
    while(~scanf("%d%d",&n,&m))
    {
        root=cnt=0;
        for(int i=0;i<n;++i)
        {
            scanf("%s",w);
            add();
        }
        bfs();
        build();
        ans=power(ans,m-1);
        out=0;
        for(int i=0;i<=cnt;++i)out=(out+ans.num[0][i])%mod;
        printf("%d\n",out);
        for(int i=0;i<=cnt;++i)tree[i].init();
    }
    return 0;
}

hdu2243
这道题就是上面那道的加强版。
蒟蒻目前TLE中,期待持续更新…

你可能感兴趣的:(AC自动机)