2012杭州网络赛的一道题,后缀数组后缀自动机都行吧。
题目大意:给一个字符串S和一系列字符串T1~Tn,问在S中有多少个不同子串满足它不是T1~Tn中任意一个字符串的子串。
思路:我们先构造S的后缀自动机,然后将每一个Ti在S的SAM上做匹配,类似于LCS,在S中的每一个状态记录一个变量deep,表示T1~Tn,在该状态能匹配的最大长度是多少,将每一个Ti匹配完之后,我们将S的SAM做拓扑排序,自底向上更新每个状态的deep,同时计算在该状态上有多少个子串满足题目要求。具体步骤如下:
1:对于当前状态,设为p,设p的par为q,则更新q->deep为q->deep和p->deep中的较大值。
2:若p->deep<p->val,则表示在状态p中,长度为p->deep+1~p->val的子串不是T1~Tn中任意字符串的子串,所以答案加上p->val-p->deep。否则表示状态p中所有字串均不满足要求,跳过即可。
(注意若p->deep==0,表示状态p中所有的子串均满足题目要求,但是答案不是加上p->val-0,而是加上 p->val-p->par->val,这表示状态p中的字符串个数,所以对于p->deep==0要特殊处理)
最后输出答案即可。
代码如下:
#include <iostream> #include <string.h> #include <stdio.h> #define maxn 200010 #define Smaxn 26 using namespace std; struct node { node *par,*go[Smaxn]; int deep; int val; }*root,*tail,que[maxn],*top[maxn]; int tot; char str[maxn>>1]; void add(int c,int l) { node *p=tail,*np=&que[tot++]; np->val=l; while(p&&p->go[c]==NULL) p->go[c]=np,p=p->par; if(p==NULL) np->par=root; else { node *q=p->go[c]; if(p->val+1==q->val) np->par=q; else { node *nq=&que[tot++]; *nq=*q; nq->val=p->val+1; np->par=q->par=nq; while(p&&p->go[c]==q) p->go[c]=nq,p=p->par; } } tail=np; } int c[maxn],len; void init(int n) { int i; for(i=0;i<=n;i++) { que[i].deep=que[i].val=0; que[i].par=NULL; memset(que[i].go,0,sizeof(que[i].go)); } tot=0; len=1; root=tail=&que[tot++]; } int max(int a,int b) { return a>b?a:b; } void solve(int q) { memset(c,0,sizeof(c)); int i; for(i=0;i<tot;i++) c[que[i].val]++; for(i=1;i<len;i++) c[i]+=c[i-1]; for(i=0;i<tot;i++) top[--c[que[i].val]]=&que[i]; while(q--) { node *p=root; scanf("%s",str); int l=strlen(str),tmp=0; for(i=0;i<l;i++) { int x=str[i]-'a'; if(p->go[x]) { tmp++; p=p->go[x]; p->deep=max(p->deep,tmp); } else { while(p&&p->go[x]==0) { p=p->par; } if(p) { tmp=p->val+1; p=p->go[x]; p->deep=max(tmp,p->deep); } else { tmp=0; p=root; } } } } long long sum=0; for(i=tot-1;i>0;i--) { node *q=top[i]; if(q->deep>0) { q->par->deep=max(q->par->deep,q->deep); if(q->deep<q->val) { sum+=q->val-q->deep; } } else { sum+=q->val-q->par->val; } } printf("%I64d\n",sum); } int main() { freopen("dd.txt","r",stdin); int ncase,time=0; scanf("%d",&ncase); while(ncase--) { printf("Case %d: ",++time); int n; scanf("%d",&n); scanf("%s",str); int i,l=strlen(str); init(l*2); for(i=0;i<l;i++) add(str[i]-'a',len++); solve(n); } return 0; }