题意:给一个字符集,构成长N的字符串,不含M个危险串的任一个,求合法字符串总数。
AC自动机+DP。。
dp[i][j]=在字符串长i,在节点j的总数。
dp[i][j]=sigma(dp[i-1][k]*move[k][j]),move[k][j]表示节点k到节点j路径数。
Time Limit: 5000MS | Memory Limit: 10000K | |
Total Submissions: 7862 | Accepted: 2132 |
Description
Input
Output
Sample Input
2 3 1 ab bb
Sample Output
5
Source
#include<iostream> #include<cstring> #include<cstdio> #include<queue> using namespace std; const int N=1022; int ch[N][128],fail[N]; bool end[N]; int root,L; int Char; char charset[N]; int newnode() { memset(ch[L],-1,sizeof ch[L]);end[L++]=0; return L-1; } void init(char s[]) { strcpy(charset,s); Char=strlen(s); L=0; root=newnode(); } int idx(char a) { for(int i=0;i<Char;i++) if(charset[i]==a)return i; } void insert(char s[]) { int n=strlen(s),u=root; for(int i=0;i<n;i++) { int& tmp=ch[u][idx(s[i])]; if(tmp==-1) tmp=newnode(); u=tmp; } end[u]=1; } void BUILD() { queue<int> q; for(int i=0;i<Char;i++) { int& tmp=ch[root][i]; if(tmp==-1) tmp=root; else { fail[tmp]=root; q.push(tmp); } } while(!q.empty()) { int u=q.front(); q.pop(); end[u]|=end[fail[u]]; for(int i=0;i<Char;i++) { int& tmp=ch[u][i]; if(tmp==-1) tmp=ch[fail[u]][i]; else { fail[tmp]=ch[fail[u]][i]; q.push(tmp); } } } } int move[N][N]; void getmove() { memset(move,0,sizeof move); for(int i=0;i<L;i++) for(int j=0;j<Char;j++) move[i][ch[i][j]]+=!end[ch[i][j]]; } struct Big { int a[121]; int len; Big() { len=1;memset(a,0,sizeof a); } void add(Big b) { len=max(len,b.len); for(int i=0;i<len;i++) { a[i+1]+=(a[i]+b.a[i])/10; a[i]=(a[i]+b.a[i])%10; } while(a[len]) len++; } void print() { for(int i=len-1;i>=0;i--) printf("%d",a[i]); putchar(10); } }; Big dp[2][N]; char s[N]; int main() { int tt,n,m; while(cin>>tt>>n>>m) { scanf("%s",s); init(s); for(int i=0;i<m;i++) { scanf("%s",s);insert(s); } BUILD(); getmove(); int pos=0; for(int i=0;i<L;i++) dp[0][i]=Big(); dp[0][0].a[0]=1; dp[0][0].len=1; for(int i=0;i<n;i++,pos=!pos) { for(int j=0;j<L;j++) { dp[!pos][j]=Big(); for(int k=0;k<L;k++) for(int t=0;t<move[k][j];t++) dp[!pos][j].add(dp[pos][k]); } } Big res=Big(); for(int i=0;i<L;i++) res.add(dp[pos][i]); res.print(); } }