POJ 1625 Censored!(自动机DP+大数相加)

题意:给出包含n个可见字符的字符集,以下所提字符串均由该字符集中的字符构成。给出p个长度不超过10的字符串,求长为m且不包含上述p个字符串的字符串有多少个。

数据范围:1<=n,m<=50,0<=p<=10

状态设计:dp[i][j],i 步之内未经过危险结点且第 i 步到达结点 j 的路径数目。

状态转移:dp[i][j]=∑dp[i-1][k],在结点 k 加输入 s[i] 能跳到结点 j

初始化:dp[0][0]=1,对于其余的 i :dp[0][i]=0

注意:由于最后结果很大,而题中又没提到取模,所以要用到大数相加。

View Code
#include <stdio.h>

#include <stdlib.h>

#include <string.h>

#include <queue>

using namespace std;

#define NODE 110

int next[NODE][50];

int fail[NODE];

bool flag[NODE];

int n,L,m,node;



char ch[51];



int dp[51][NODE][100];



int cmp(const void *a,const void *b)

{

    return *(char*)a-*(char*)b;

}

void init()

{

    node=1;

    memset(next[0],0,sizeof(next[0]));

}

void add(int cur,int k)

{

    memset(next[node],0,sizeof(next[node]));

    flag[node]=0;

    next[cur][k]=node++;

}

int hash(char c)

{

    int min=0,max=n,mid;

    while(min+1!=max)

    {

        mid=min+max>>1;

        if(ch[mid]>c)    max=mid;

        else    min=mid;

    }

    return min;

}

void insert(char *s)

{

    int i,cur,k;

    for(i=cur=0;s[i];i++)

    {

        k=hash(s[i]);

        if(!next[cur][k])   add(cur,k);

        cur=next[cur][k];

    }

    flag[cur]=1;

}

void build_ac()

{

    queue<int>q;

    int cur,nxt,tmp,k;



    fail[0]=0;

    q.push(0);



    while(!q.empty())

    {

        cur=q.front(),q.pop();

        for(k=0;k<n;k++)

        {

            nxt=next[cur][k];

            if(nxt)

            {

                if(!cur)    fail[nxt]=0;

                else

                {

                    for(tmp=fail[cur];tmp&&!next[tmp][k];tmp=fail[tmp]);

                    fail[nxt]=next[tmp][k];

                }

                if(flag[fail[nxt]]) flag[nxt]=1;

                q.push(nxt);

            }

            else    next[cur][k]=next[fail[cur]][k];

        }

    }

}

void ADD(int *a,int *b)

{

    int i,c=0;

    for(i=0;i<100;i++)

    {

        a[i]+=b[i]+c;

        c=a[i]/10;

        a[i]%=10;

    }

}

void solve()

{

    memset(dp,0,sizeof(dp));

    dp[0][0][0]=1;



    for(int step=1;step<=L;step++)

    {

        for(int pre=0;pre<node;pre++)

        {

            if(flag[pre])   continue;

            for(int k=0;k<n;k++)

            {

                int cur=next[pre][k];

                if(flag[cur])   continue;

                ADD(dp[step][cur],dp[step-1][pre]);

            }

        }

    }



    int ans[100],i;

    memset(ans,0,sizeof(ans));

    for(i=0;i<node;i++) if(!flag[i])    ADD(ans,dp[L][i]);



    for(i=99;i>=0 && ans[i]==0;i--);

    if(i<0) puts("0");

    else

    {

        for(;i>=0;i--)  printf("%d",ans[i]);

        puts("");

    }

}

int main()

{

    char s[51];

    while(~scanf("%d%d%d",&n,&L,&m))

    {

        getchar();

        gets(ch);

        qsort(ch,strlen(ch),sizeof(char),cmp);



        init();

        for(int i=0;i<m;i++)

        {

            gets(s);

            insert(s);

        }

        build_ac();

        solve();

    }

    return 0;

}

 

你可能感兴趣的:(poj)