题意:给出一个字符集V和P个模式串(长度小于10),问由这个字符集中字符组成的长度为N的且不包含任意一个模式串的字符串有多少个?(字符集大小,N<=50, P <= 10) 。
思路:先将P个模式串建立AC自动机,标记好危险节点(flag数组)。然后动归来求:dp[i][j]表示长度为i且最后在节点j的字符串个数(节点j必为安全节点),初始dp[0][1] = 1, 其他dp[i][j] = 0。由dp[i][j] 可以导出,每个由j可以到达的安全节点son[j],执行:dp[i+1][son[j]] += dp[i][j]。因为从根走i步到达节点j有n种走法,那么走i+1步到达son[j]的走法就要加n。最终的答案为∑{dp[N][j] | j是安全节点}。
最后的数量很大,需要用数组存数。
需要注意:如果是用指针来存储,那么son[j]不仅指从j通过一条字母边直接到达的son[j], 也可以是通过若干前缀指针后再通过一个字母边到达son[j],(即son[j]并不一定是 j 的子节点)。而用数组存储恰恰避免了这一点。
#include <cstdio> #include <cstring> #include <algorithm> #include <map> #include <queue> #include <cstdlib> using namespace std; #define INF 0x3fffffff #define N 505 #define M 55 int n,m,p; int t[N][M],fail[N],top; bool flag[N]; char word[M],s[M]; struct dp{ int num[100]; bool has; }dp[M][N]; int res[100]; map<char, int> hh; queue<int> q; void init(){ int i; memset(t, -1, sizeof(t)); memset(fail, 0, sizeof(fail)); memset(flag, false, sizeof(flag)); for(i = 0;i<n;i++) t[0][i] = 1; top = 1; hh.clear(); for(i = 0;word[i];i++)//字母表和0...n-1的对应 hh[word[i]] = i; } void insert(char* s){ int i,r = 1; for(i = 0;s[i];i++){ if(t[r][hh[s[i]]] == -1) t[r][hh[s[i]]] = (++top); r = t[r][hh[s[i]]]; } flag[r] = true; } void buildDFA(){ int i,now; q.push(1); while(!q.empty()){ now = q.front(); q.pop(); for(i = 0;i<n;i++){ if(t[now][i] == -1) t[now][i] = t[fail[now]][i]; else{ fail[t[now][i]] = t[fail[now]][i]; q.push(t[now][i]); if(flag[t[fail[now]][i]])//危险节点建立好 flag[t[now][i]] = true; } } } } void add(int* a,int* b){//大数相加,把b加到a int i,j; for(i = j = 0;i<100;i++){ a[i] += b[i]+j; j = a[i]/10; a[i] %= 10; } } int main(){ int i,j,k; while(scanf("%d %d %d\n",&n,&m,&p)!=EOF){ gets(word); init(); for(i = 1;i<=p;i++){ gets(s); insert(s); } buildDFA(); for(i = 0;i<=m;i++) for(j = 1;j<=top;j++){ memset(dp[i][j].num, 0, sizeof(dp[i][j].num)); dp[i][j].has = false; } dp[0][1].num[0] = 1; dp[0][1].has = true; for(i = 0;i<m;i++) for(j = 1;j<=top;j++) for(k = 0;k<n;k++) if(dp[i][j].has && !flag[t[j][k]]){ add(dp[i+1][t[j][k]].num , dp[i][j].num); dp[i+1][t[j][k]].has = true; } memset(res, 0, sizeof(res)); for(i = 1;i<=top;i++) add(res,dp[m][i].num); for(i = 99;i>=0&&!res[i];i--); if(i==-1) putchar('0'); for(;i>=0;i--) printf("%d",res[i]); putchar('\n'); } return 0; }